1 change: 1 addition & 0 deletions deep_gemm/__init__.py
@@ -5,6 +5,7 @@
gemm_fp8_fp8_bf16_nt,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
m_grouped_gemm_fp8_fp8_bf16_nt_offset,
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
ceil_div,
805 changes: 804 additions & 1 deletion deep_gemm/include/deep_gemm/fp8_gemm.cuh

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions deep_gemm/include/deep_gemm/mma_utils.cuh
@@ -32,6 +32,30 @@ struct SM90_U32x4_STSM_N {
}
};

template <typename dtype_t>
struct SM90_U32x2_STSM_T
{
__device__ __forceinline__ static void copy(dtype_t src_0, dtype_t src_1, void* smem_dst)
{
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16.trans [%0], {%1, %2};\n" ::"l"(smem_dst), "r"(src[0]),
"r"(src[1]));
}
};

template <typename dtype_t>
struct SM90_U32x4_STSM_T
{
__device__ __forceinline__ static void copy(
dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst)
{
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" ::"l"(smem_dst),
"r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
}
};

__forceinline__ __device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
263 changes: 262 additions & 1 deletion deep_gemm/include/deep_gemm/scheduler.cuh
@@ -7,7 +7,8 @@ namespace deep_gemm {
enum class GemmType {
Normal,
GroupedContiguous,
GroupedMasked
GroupedMasked,
GroupedWithOffset
};

#pragma clang diagnostic push
Expand Down Expand Up @@ -158,6 +159,266 @@ struct Scheduler {
}
};


template <uint32_t kNumTMAMulticast, uint32_t kNumNBlocks, uint32_t kNumNBlocksPerGroup>
__device__ __forceinline__ void offset_get_swizzled_block_idx(
const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx)
{
DG_STATIC_ASSERT(kNumNBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");

// Swizzle for better L2 usage
auto num_blocks_per_group = num_m_blocks * kNumNBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_n_block_idx = group_idx * kNumNBlocksPerGroup;
auto num_n_blocks_in_group = min(kNumNBlocksPerGroup, kNumNBlocks - first_n_block_idx);
auto in_group_idx = block_idx % num_blocks_per_group;
m_block_idx = in_group_idx / num_n_blocks_in_group;
n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group;
}
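
For readers following the scheduling logic, the swizzle above can be transcribed on the host. The Python sketch below is illustrative only (the function and variable names are chosen for this note, not part of the PR) and reproduces how a flat block index is split into a swizzled (m_block_idx, n_block_idx) pair.

```python
# Illustrative host-side transcription of offset_get_swizzled_block_idx
# (hypothetical helper, not part of the PR). Blocks are processed in bands of
# k_num_n_blocks_per_group N-blocks, with N varying fastest inside a band, so
# the working set of B tiles stays small and L2 reuse improves.
def swizzled_block_idx(num_m_blocks, block_idx,
                       k_num_n_blocks, k_num_n_blocks_per_group):
    num_blocks_per_group = num_m_blocks * k_num_n_blocks_per_group
    group_idx = block_idx // num_blocks_per_group
    first_n_block_idx = group_idx * k_num_n_blocks_per_group
    num_n_blocks_in_group = min(k_num_n_blocks_per_group,
                                k_num_n_blocks - first_n_block_idx)
    in_group_idx = block_idx % num_blocks_per_group
    m_block_idx = in_group_idx // num_n_blocks_in_group
    n_block_idx = first_n_block_idx + in_group_idx % num_n_blocks_in_group
    return m_block_idx, n_block_idx

# Example (made-up shapes): 4 M-blocks, 20 N-blocks, bands of 16 N-blocks.
assert swizzled_block_idx(4, 0, 20, 16) == (0, 0)
assert swizzled_block_idx(4, 16, 20, 16) == (1, 0)   # wraps to the next M row
assert swizzled_block_idx(4, 64, 20, 16) == (0, 16)  # second N band begins
```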



struct GroupedWithOffsetSchedulerInput
{
uint32_t shape_m;
int64_t* problem_m_offsets;
};

struct GroupedWithOffsetSchedulerInputSwapAB
{
uint32_t shape_m;
int64_t* problem_n_offsets;
};

struct StridedBatchedSchedulerInput
{
uint32_t shape_m;
uint64_t ld_a;
uint64_t stride_a;
uint64_t ld_b;
uint64_t stride_b;
uint64_t ld_d;
uint64_t stride_d;
};

struct StridedBatchedSchedulerInputSwapAB
{
uint32_t shape_n;
uint64_t ld_a;
uint64_t stride_a;
uint64_t ld_b;
uint64_t stride_b;
uint64_t ld_d;
uint64_t stride_d;
};


// Must be kept in sync with the copy in tests/unittest/_torch/thop/deep_gemm_tests.py
template <typename T_offset, typename T_index>
__host__ __device__ __forceinline__ T_offset compute_padded_offset(T_offset offset, T_index problem_idx)
{
// This formulation ensures that padded_offset[i + 1] - padded_offset[i] >= offset[i + 1] - offset[i].
constexpr T_offset alignment = 32;
return (offset + problem_idx * (alignment - 1)) / alignment * alignment;
}
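
To see the guarantee stated in the comment above, it helps to evaluate the formula on a small offset array. The Python sketch below is an illustrative mirror of the C++ helper (the alignment of 32 comes from the code; the example offsets are made up) and asserts that padded gaps never shrink relative to the original gaps.

```python
# Illustrative check of compute_padded_offset (mirrors the C++ formula above).
ALIGNMENT = 32

def compute_padded_offset(offset, problem_idx, alignment=ALIGNMENT):
    return (offset + problem_idx * (alignment - 1)) // alignment * alignment

# Example: cumulative row offsets for group sizes 30 and 40.
offsets = [0, 30, 70]
padded = [compute_padded_offset(o, i) for i, o in enumerate(offsets)]
print(padded)  # [0, 32, 128]

# padded_offset[i + 1] - padded_offset[i] >= offset[i + 1] - offset[i]
for i in range(len(offsets) - 1):
    assert padded[i + 1] - padded[i] >= offsets[i + 1] - offsets[i]
```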

template <uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N), uint32_t kNumNBlocksPerGroup = 16>
struct GroupedWithOffsetScheduler
{
static constexpr GemmType gemm_type = GemmType::GroupedWithOffset;

int current_iter = -1;
uint32_t curr_group_idx;
uint32_t curr_cumsum;
int64_t m_offset;
int64_t m_padded_4_offset;
int64_t m_boundary;
int64_t* problem_m_offsets;

using Input = GroupedWithOffsetSchedulerInput;
Input input;

GroupedWithOffsetScheduler() {}

__device__ __forceinline__ GroupedWithOffsetScheduler(Input& input)
{
this->problem_m_offsets = input.problem_m_offsets;
curr_group_idx = 0;
curr_cumsum = 0;
}

__device__ __forceinline__ uint32_t get_global_m_idx(uint32_t const& block_idx)
{
return m_offset + block_idx * BLOCK_M;
}

__device__ __forceinline__ uint32_t get_global_n_idx(
uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
{
return curr_group_idx * shape_dim + block_idx * block_size;
}

__device__ __forceinline__ uint32_t get_global_scales_a_idx(uint32_t const& block_idx)
{
return m_padded_4_offset + block_idx * BLOCK_M;
}

__device__ __forceinline__ uint32_t get_global_scales_b_idx(
uint32_t const shape_dim, uint32_t const block_size, uint32_t const& block_idx, uint32_t const& m_block_idx = 0)
{
return curr_group_idx * shape_dim + block_idx * block_size;
}

__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
{
++current_iter;
auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
uint32_t num_m_blocks;
while (true)
{
// End of the task
if (curr_group_idx == kNumGroups)
return false;
m_offset = __ldg(problem_m_offsets + curr_group_idx);
m_boundary = __ldg(problem_m_offsets + curr_group_idx + 1);
m_padded_4_offset = compute_padded_offset(m_offset, curr_group_idx);
auto m = m_boundary - m_offset;
// Within current group
num_m_blocks = ceil_div(m, static_cast<int64_t>(BLOCK_M));
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
break;

// Move to check the next group
curr_group_idx++;
curr_cumsum = current_m_block_cumsum;
}

offset_get_swizzled_block_idx<kNumTMAMulticast, kNumNBlocks, kNumNBlocksPerGroup>(
num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
return true;
}
};
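
The group search in get_next_block amounts to a linear scan over the prefix-sum array problem_m_offsets until the flat block index falls inside a group's slice of the block grid. The Python sketch below is an illustrative reconstruction of that lookup (helper names and the example shapes are invented for this note); the real kernel then feeds the in-group index into the swizzle shown earlier.

```python
import math

# Illustrative reconstruction of the group lookup in
# GroupedWithOffsetScheduler::get_next_block (names invented for this sketch).
def resolve_block(next_block_idx, problem_m_offsets, block_m, k_num_n_blocks):
    curr_cumsum = 0
    for group_idx in range(len(problem_m_offsets) - 1):
        m = problem_m_offsets[group_idx + 1] - problem_m_offsets[group_idx]
        num_m_blocks = math.ceil(m / block_m)
        if next_block_idx < (curr_cumsum + num_m_blocks) * k_num_n_blocks:
            # Index within this group; the kernel then swizzles it into
            # (m_block_idx, n_block_idx) as shown earlier.
            in_group = next_block_idx - curr_cumsum * k_num_n_blocks
            return group_idx, in_group
        curr_cumsum += num_m_blocks
    return None  # past the last group: the CTA has no more work

# Example: two groups with 100 and 50 rows, BLOCK_M = 64, 8 N-blocks in total.
assert resolve_block(17, [0, 100, 150], 64, 8) == (1, 1)  # lands in group 1
assert resolve_block(30, [0, 100, 150], 64, 8) is None    # beyond all groups
```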

template <uint32_t SHAPE_M, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumGroups, uint32_t kNumTMAMulticast,
uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M), uint32_t kNumMBlocksPerGroup = 16>
struct GroupedWithOffsetSchedulerSwapAB
{
static constexpr GemmType gemm_type = GemmType::GroupedWithOffset;

int current_iter = -1;
uint32_t curr_group_idx;
uint32_t curr_cumsum;
int64_t n_offset;
int64_t n_padded_4_offset;
int64_t n_boundary;
int64_t* problem_n_offsets;

using Input = GroupedWithOffsetSchedulerInputSwapAB;
Input input;

GroupedWithOffsetSchedulerSwapAB() {}

__device__ __forceinline__ GroupedWithOffsetSchedulerSwapAB(Input& input)
{
this->problem_n_offsets = input.problem_n_offsets;
curr_group_idx = 0;
curr_cumsum = 0;
}

// weight
__device__ __forceinline__ uint32_t get_global_m_idx(
const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
{
return curr_group_idx * shape_dim + block_idx * block_size;
}

// act
__device__ __forceinline__ uint32_t get_global_n_idx(uint32_t const& block_idx)
{
return n_offset + block_idx * BLOCK_N;
}

// act scales
__device__ __forceinline__ uint32_t get_global_scales_b_idx(uint32_t const& block_idx)
{
return n_padded_4_offset + block_idx * BLOCK_N;
}

// weight scales
__device__ __forceinline__ uint32_t get_global_scales_a_idx(
const uint32_t shape_dim, const uint32_t block_size, uint32_t const& block_idx, uint32_t const& n_block_idx = 0)
{
return curr_group_idx * shape_dim + block_idx * block_size;
}

__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx)
{
++current_iter;
auto const next_block_idx = current_iter * gridDim.x + blockIdx.x;
uint32_t num_n_blocks;
while (true)
{
// End of the task
if (curr_group_idx == kNumGroups)
return false;
n_offset = __ldg(problem_n_offsets + curr_group_idx);
n_boundary = __ldg(problem_n_offsets + curr_group_idx + 1);
n_padded_4_offset = compute_padded_offset(n_offset, curr_group_idx);
auto n = n_boundary - n_offset;
// Within current group
num_n_blocks = ceil_div(n, static_cast<int64_t>(BLOCK_N));
auto current_n_block_cumsum = curr_cumsum + num_n_blocks;
if (next_block_idx < current_n_block_cumsum * kNumMBlocks)
break;

// Move to check the next group
curr_group_idx++;
curr_cumsum = current_n_block_cumsum;
}

offset_get_swizzled_block_idx<kNumTMAMulticast, kNumMBlocks, kNumMBlocksPerGroup>(
num_n_blocks, next_block_idx - curr_cumsum * kNumMBlocks, n_block_idx, m_block_idx);
return true;
}
};

template <GemmType GT, uint32_t SHAPE_N, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
uint32_t kNumNBlocksPerGroup = 16>
struct SchedulerSelector
{
static constexpr auto select_type()
{
if constexpr (GT == GemmType::GroupedWithOffset)
return GroupedWithOffsetScheduler<SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumNBlocks,
kNumNBlocksPerGroup>();
}

using type = decltype(select_type());
};

template <GemmType GT, uint32_t SHAPE_M, uint32_t SHAPE_K, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups, uint32_t kNumTMAMulticast, uint32_t kNumMBlocks = ceil_div(SHAPE_M, BLOCK_M),
uint32_t kNumMBlocksPerGroup = 16>
struct SchedulerSelectorSwapAB
{
static constexpr auto select_type()
{
static_assert(GT == GemmType::GroupedWithOffset || GT == GemmType::Normal,
"Only GroupedWithOffset and Normal are supported for SwapAB");
if constexpr (GT == GemmType::Normal)
return NormalSchedulerSwapAB<SHAPE_M, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kNumMBlocks,
kNumMBlocksPerGroup>();
if constexpr (GT == GemmType::GroupedWithOffset)
return GroupedWithOffsetSchedulerSwapAB<SHAPE_M, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast,
kNumMBlocks, kNumMBlocksPerGroup>();
}

using type = decltype(select_type());
};

#pragma clang diagnostic pop

} // namespace deep_gemm
3 changes: 2 additions & 1 deletion deep_gemm/jit_kernels/__init__.py
@@ -1,7 +1,8 @@
from .gemm import gemm_fp8_fp8_bf16_nt
from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
m_grouped_gemm_fp8_fp8_bf16_nt_offset
)
from .wgrad_gemm import (
wgrad_gemm_fp8_fp8_fp32_nt,