40 changes: 39 additions & 1 deletion benchmarks/gemm/gemm.py
@@ -7,11 +7,13 @@ class gemm_benchmark(Benchmark):

op_type = Gemm

def __init__(self, M, N, K, dtype):
def __init__(self, M, N, K, dtype, trans_A=False, trans_B=False):
self.M = M
self.N = N
self.K = K
self.dtype = dtype
self.trans_A = trans_A
self.trans_B = trans_B

@property
def total_flops(self):
@@ -27,4 +29,40 @@ def gen_inputs(self):
return A, B

def ref_program(self, A: torch.Tensor, B: torch.Tensor):
if self.trans_A:
A = A.T
if self.trans_B:
B = B.T
return torch.matmul(A, B)


class matmul_benchmark(Benchmark):

def __init__(self, M, N, K, dtype, grad=True):
self.M = M
self.N = N
self.K = K
self.dtype = dtype
self.grad = grad

@property
def total_flops(self):
        # forward 2*M*N*K plus 2*M*N*K each for the dA and dB backward GEMMs (assumes grad=True)
        return 6.0 * self.M * self.N * self.K

@property
def total_memory(self):
return 3 * (self.M * self.K + self.K * self.N + self.M * self.N) * self.dtype.itemsize

def gen_inputs(self):
A = torch.randn(self.M, self.K, device='cuda', dtype=self.dtype, requires_grad=self.grad)
B = torch.randn(self.K, self.N, device='cuda', dtype=self.dtype, requires_grad=self.grad)
return A, B

def ref_program(self, A: torch.Tensor, B: torch.Tensor):
output = torch.matmul(A, B)
if not self.grad:
return output
else:
loss = output.sum()
loss.backward()
return output, A.grad, B.grad
25 changes: 25 additions & 0 deletions tests/functions/test_matmul.py
@@ -0,0 +1,25 @@
import argparse
from top.functions import matmul
from top.utils import str2dtype
from benchmarks import matmul_benchmark


def test_matmul(M, N, K, dtype, tune=False):
fn = matmul(M, N, K, dtype, tune=tune)
benchmark = matmul_benchmark(M, N, K, dtype)

inputs = benchmark.gen_inputs()
benchmark.check_fn(fn, *inputs)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--M', type=int, default=1024, help='M')
parser.add_argument('--N', type=int, default=1024, help='N')
parser.add_argument('--K', type=int, default=1024, help='K')
parser.add_argument(
'--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
args = parser.parse_args()

test_matmul(args.M, args.N, args.K, str2dtype[args.dtype], args.tune)
30 changes: 30 additions & 0 deletions tests/layers/test_linear.py
@@ -0,0 +1,30 @@
import argparse
import torch
from top.layers import Linear
from top.utils import str2dtype


def test_linear(M, N, K, dtype, tune=False):
linear_layer = Linear(M, N, K, dtype=dtype, tune=tune)
input = torch.randn(M, K, dtype=dtype, device='cuda', requires_grad=True)

output = linear_layer(input)

loss = output.sum()
loss.backward()

print("Output shape:", output.shape)
print("Gradient shape:", input.grad.shape)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--M', type=int, default=1024, help='M')
parser.add_argument('--N', type=int, default=1024, help='N')
parser.add_argument('--K', type=int, default=1024, help='K')
parser.add_argument(
'--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
args = parser.parse_args()

test_linear(args.M, args.N, args.K, str2dtype[args.dtype], args.tune)
10 changes: 6 additions & 4 deletions tests/ops/test_gemm.py
@@ -4,9 +4,9 @@
from benchmarks import gemm_benchmark


def test_gemm(M, N, K, dtype, tune=False):
op = Gemm(M, N, K, dtype, tune=tune)
benchmark = gemm_benchmark(M, N, K, dtype)
def test_gemm(M, N, K, dtype, trans_A=False, trans_B=False, tune=False):
op = Gemm(M, N, K, trans_A=trans_A, trans_B=trans_B, dtype=dtype, tune=tune)
benchmark = gemm_benchmark(M, N, K, dtype, trans_A=trans_A, trans_B=trans_B)

inputs = benchmark.gen_inputs()
benchmark.check(op, *inputs)
@@ -20,7 +20,9 @@ def test_gemm(M, N, K, dtype, tune=False):
parser.add_argument('--K', type=int, default=1024, help='K')
parser.add_argument(
'--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
parser.add_argument('--trans_A', action='store_true', default=False, help='transpose input A')
parser.add_argument('--trans_B', action='store_true', default=False, help='transpose input B')
parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
args = parser.parse_args()

test_gemm(args.M, args.N, args.K, str2dtype[args.dtype], args.tune)
test_gemm(args.M, args.N, args.K, str2dtype[args.dtype], args.trans_A, args.trans_B, args.tune)
3 changes: 2 additions & 1 deletion top/functions/__init__.py
@@ -3,4 +3,5 @@
from .mha_decode import *
from .gqa_decode import *
from .mla_decode import *
from .sparse_mla import *
from .sparse_mla import *
from .matmul import *
46 changes: 46 additions & 0 deletions top/functions/matmul.py
@@ -0,0 +1,46 @@
import torch
from .function import Function
from top.ops.gemm import Gemm

__all__ = ['matmul']


class gemm_ctx(torch.autograd.Function):

@staticmethod
def forward(ctx, A, B, fwd_op, da_bwd_op, db_bwd_op):
O = fwd_op(A, B)
Review comment (Contributor, severity: high):

The input tensors A and B are passed directly to fwd_op, which likely calls a custom CUDA kernel. These kernels often require input tensors to be contiguous in memory to function correctly. To prevent potential errors or incorrect computations with non-contiguous inputs, you should explicitly make them contiguous before the forward operation.

Suggested change (replace the O = fwd_op(A, B) line with):

        A = A.contiguous()
        B = B.contiguous()
        O = fwd_op(A, B)
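
For context (an illustration added alongside this review note, not part of the PR): a common way a non-contiguous tensor reaches a call like fwd_op(A, B) is a transposed view, which shares storage with the original tensor but has swapped strides. A minimal sketch:

    import torch

    B = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    Bt = B.T                   # transposed view: same storage, swapped strides
    print(Bt.is_contiguous())  # False -- a kernel that indexes row-major memory directly
                               #          would read the wrong elements for this layout
    Bt = Bt.contiguous()       # materializes a row-major copy
    print(Bt.is_contiguous())  # True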


ctx.save_for_backward(A, B)
ctx.da_bwd_op = da_bwd_op
ctx.db_bwd_op = db_bwd_op

return O

@staticmethod
def backward(ctx, dO):
A, B = ctx.saved_tensors

dO = dO.contiguous()
dA = ctx.da_bwd_op(dO, B)
dB = ctx.db_bwd_op(A, dO)

return dA, dB, None, None, None


class matmul(Function):

def __init__(
self,
M: int,
N: int,
K: int,
dtype=torch.float16,
tune=False,
):
self.fwd_op = Gemm(M, N, K, dtype=dtype, tune=tune)
        # dA = dO @ B^T: B keeps its (K, N) layout, so this (M, K, N) GEMM transposes B
        self.da_bwd_op = Gemm(M, K, N, dtype=dtype, trans_B=True, tune=tune)
        # dB = A^T @ dO: A keeps its (M, K) layout, so this (K, N, M) GEMM transposes A
        self.db_bwd_op = Gemm(K, N, M, dtype=dtype, trans_A=True, tune=tune)

def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
return gemm_ctx.apply(A, B, self.fwd_op, self.da_bwd_op, self.db_bwd_op)
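
For reference (not part of the PR), the shape algebra behind the three Gemm instances above: with O = A @ B, the standard gradients are dA = dO @ B^T and dB = A^T @ dO, so the backward passes are an (M, K, N)-shaped and a (K, N, M)-shaped GEMM that consume B and A in their stored, i.e. transposed, layouts. A quick check against autograd:

    import torch

    M, N, K = 4, 5, 6
    A = torch.randn(M, K, requires_grad=True)
    B = torch.randn(K, N, requires_grad=True)

    O = A @ B                  # forward: (M, K) @ (K, N) -> (M, N)
    dO = torch.randn_like(O)   # upstream gradient

    dA = dO @ B.T              # (M, N) @ (N, K) -> (M, K): an M x K x N gemm, B transposed
    dB = A.T @ dO              # (K, M) @ (M, N) -> (K, N): a K x N x M gemm, A transposed

    O.backward(dO)
    assert torch.allclose(dA, A.grad) and torch.allclose(dB, B.grad)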
44 changes: 34 additions & 10 deletions top/kernels/gemm/gemm.py
@@ -7,22 +7,27 @@
from top.kernels import Kernel


def _gemm_kernel(M, N, K, dtype='float16'):
def _gemm_kernel(M, N, K, trans_A, trans_B, dtype='float16'):
accum_dtype = "float"

@tilelang.jit(out_idx=[-1], compile_flags=["-O3", "-DENABLE_BF16"])
def _gemm_func(block_M, block_N, block_K, threads, num_stages, enable_rasteration):

A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

@T.prim_func
def _gemm_main(
A: T.Tensor((M, K), dtype), # type: ignore
B: T.Tensor((K, N), dtype), # type: ignore
A: T.Tensor(A_shape, dtype), # type: ignore
B: T.Tensor(B_shape, dtype), # type: ignore
C: T.Tensor((M, N), dtype), # type: ignore
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
A_shared = T.alloc_shared(A_shared_shape, dtype)
B_shared = T.alloc_shared(B_shared_shape, dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)

@@ -34,9 +39,20 @@ def _gemm_main(
T.clear(C_local)

for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
if not trans_A:
# A: (M, K)
T.copy(A[by * block_M, k * block_K], A_shared) # [block_M, block_K]
else:
# A: (K, M)
T.copy(A[k * block_K, by * block_M], A_shared) # [block_K, block_M]

if not trans_B:
# B: (K, N)
T.copy(B[k * block_K, bx * block_N], B_shared) # [block_K, block_N]
else:
# B: (N, K)
T.copy(B[bx * block_N, k * block_K], B_shared) # [block_N, block_K]
T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)

T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
@@ -49,14 +65,22 @@ def _gemm_main(
class gemm_kernel(Kernel):
supported_archs: list[int] = [80, 89, 90]

def __init__(self, M, N, K, dtype, config: Optional[dict] = None, tune=False):
def __init__(self,
M,
N,
K,
dtype,
config: Optional[dict] = None,
tune=False,
trans_A=False,
trans_B=False):
super().__init__()
self.M = M
self.N = N
self.K = K
self.dtype = dtype

self.kernel = _gemm_kernel(M, N, K, self.dtype_str)
self.kernel = _gemm_kernel(M, N, K, trans_A, trans_B, self.dtype_str)

self.init_config(config, tune)

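For orientation (a sketch, not part of the diff), how the new trans_A/trans_B path is exercised end to end. It assumes the Gemm signature shown in top/ops/gemm.py later in this diff and that a Gemm instance is callable as op(A, B), as in top/functions/matmul.py above; the tolerances are illustrative values for float16:

    import torch
    from top.ops.gemm import Gemm

    M, N, K = 1024, 512, 256
    op = Gemm(M, N, K, trans_A=True, trans_B=True, dtype=torch.float16)

    A = torch.randn(K, M, device='cuda', dtype=torch.float16)  # stored as (K, M)
    B = torch.randn(N, K, device='cuda', dtype=torch.float16)  # stored as (N, K)
    C = op(A, B)                                               # output is still (M, N)

    ref = torch.matmul(A.T, B.T)       # mirrors gemm_benchmark.ref_program
    torch.testing.assert_close(C, ref, rtol=1e-2, atol=1e-2)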
3 changes: 2 additions & 1 deletion top/layers/__init__.py
@@ -1,3 +1,4 @@
from .flash_attn import *
from .flash_decode import *
from .deepseek_mla import *
from .deepseek_mla import *
from .linear import *
34 changes: 34 additions & 0 deletions top/layers/linear.py
@@ -0,0 +1,34 @@
import math
import torch
from torch import nn
from top.functions import matmul


class Linear(nn.Module):

def __init__(
self,
batch_size: int,
out_features: int,
in_features: int,
device='cuda',
dtype=torch.float16,
tune=False,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = nn.Parameter(torch.empty((in_features, out_features), **factory_kwargs))
self.fn = matmul(
batch_size,
out_features,
in_features,
dtype=self.weight.dtype,
tune=tune,
)
self.reset_parameters()

def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.fn(x, self.weight)
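
Note on the weight layout (an illustrative check, not part of the PR): the weight is stored as (in_features, out_features), the transpose of nn.Linear's (out_features, in_features), so the forward pass is a plain x @ weight with no transpose. Equivalence sketch:

    import torch

    M, N, K = 8, 16, 32                       # batch, out_features, in_features
    x = torch.randn(M, K)
    w = torch.randn(K, N)                     # layout used by the Linear layer above

    ref = torch.nn.functional.linear(x, w.T)  # nn.Linear convention: weight is (out, in)
    out = x @ w                               # this layer's convention: weight is (in, out)
    torch.testing.assert_close(out, ref)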
5 changes: 4 additions & 1 deletion top/ops/gemm.py
@@ -12,6 +12,8 @@ def __init__(self,
M: int,
N: int,
K: int,
trans_A=False,
trans_B=False,
dtype=torch.float16,
kernel_map: Optional[Dict[str, Kernel]] = None,
tune=False):
@@ -22,7 +24,8 @@ def __init__(self,
self.dtype = dtype

self.dispatch_kernel(kernel_map)
self.kernel = self.kernel_map["gemm_kernel"](M, N, K, self.dtype, tune=tune)
self.kernel = self.kernel_map["gemm_kernel"](
M, N, K, self.dtype, tune=tune, trans_A=trans_A, trans_B=trans_B)

@property
def default_kernel_map(self):