diff --git a/benchmarks/gemm/gemm.py b/benchmarks/gemm/gemm.py
index f77ee76..5c6e60f 100644
--- a/benchmarks/gemm/gemm.py
+++ b/benchmarks/gemm/gemm.py
@@ -7,11 +7,13 @@ class gemm_benchmark(Benchmark):
 
     op_type = Gemm
 
-    def __init__(self, M, N, K, dtype):
+    def __init__(self, M, N, K, dtype, trans_A=False, trans_B=False):
         self.M = M
         self.N = N
         self.K = K
         self.dtype = dtype
+        self.trans_A = trans_A
+        self.trans_B = trans_B
 
     @property
     def total_flops(self):
@@ -27,4 +29,40 @@ def gen_inputs(self):
         return A, B
 
     def ref_program(self, A: torch.Tensor, B: torch.Tensor):
+        if self.trans_A:
+            A = A.T
+        if self.trans_B:
+            B = B.T
         return torch.matmul(A, B)
+
+
+class matmul_benchmark(Benchmark):
+
+    def __init__(self, M, N, K, dtype, grad=True):
+        self.M = M
+        self.N = N
+        self.K = K
+        self.dtype = dtype
+        self.grad = grad
+
+    @property
+    def total_flops(self):
+        return 6.0 * self.M * self.N * self.K
+
+    @property
+    def total_memory(self):
+        return 3 * (self.M * self.K + self.K * self.N + self.M * self.N) * self.dtype.itemsize
+
+    def gen_inputs(self):
+        A = torch.randn(self.M, self.K, device='cuda', dtype=self.dtype, requires_grad=self.grad)
+        B = torch.randn(self.K, self.N, device='cuda', dtype=self.dtype, requires_grad=self.grad)
+        return A, B
+
+    def ref_program(self, A: torch.Tensor, B: torch.Tensor):
+        output = torch.matmul(A, B)
+        if not self.grad:
+            return output
+        else:
+            loss = output.sum()
+            loss.backward()
+            return output, A.grad, B.grad
diff --git a/tests/functions/test_matmul.py b/tests/functions/test_matmul.py
new file mode 100644
index 0000000..982fc0e
--- /dev/null
+++ b/tests/functions/test_matmul.py
@@ -0,0 +1,25 @@
+import argparse
+from top.functions import matmul
+from top.utils import str2dtype
+from benchmarks import matmul_benchmark
+
+
+def test_matmul(M, N, K, dtype, tune=False):
+    fn = matmul(M, N, K, dtype, tune=tune)
+    benchmark = matmul_benchmark(M, N, K, dtype)
+
+    inputs = benchmark.gen_inputs()
+    benchmark.check_fn(fn, *inputs)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--M', type=int, default=1024, help='M')
+    parser.add_argument('--N', type=int, default=1024, help='N')
+    parser.add_argument('--K', type=int, default=1024, help='K')
+    parser.add_argument(
+        '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
+    parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
+    args = parser.parse_args()
+
+    test_matmul(args.M, args.N, args.K, str2dtype[args.dtype], args.tune)
\ No newline at end of file
diff --git a/tests/layers/test_linear.py b/tests/layers/test_linear.py
new file mode 100644
index 0000000..02a651b
--- /dev/null
+++ b/tests/layers/test_linear.py
@@ -0,0 +1,30 @@
+import argparse
+import torch
+from top.layers import Linear
+from top.utils import str2dtype
+
+
+def test_linear(M, N, K, dtype, tune=False):
+    linear_layer = Linear(M, N, K, dtype=dtype, tune=tune)
+    input = torch.randn(M, K, dtype=dtype, device='cuda', requires_grad=True)
+
+    output = linear_layer(input)
+
+    loss = output.sum()
+    loss.backward()
+
+    print("Output shape:", output.shape)
+    print("Gradient shape:", input.grad.shape)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--M', type=int, default=1024, help='M')
+    parser.add_argument('--N', type=int, default=1024, help='N')
+    parser.add_argument('--K', type=int, default=1024, help='K')
+    parser.add_argument(
+        '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
+    parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
+    args = parser.parse_args()
+
+    test_linear(args.M, args.N, args.K, str2dtype[args.dtype], args.tune)
\ No newline at end of file
diff --git a/tests/ops/test_gemm.py b/tests/ops/test_gemm.py
index 9311a9a..aeca400 100644
--- a/tests/ops/test_gemm.py
+++ b/tests/ops/test_gemm.py
@@ -4,9 +4,9 @@
 from benchmarks import gemm_benchmark
 
 
-def test_gemm(M, N, K, dtype, tune=False):
-    op = Gemm(M, N, K, dtype, tune=tune)
-    benchmark = gemm_benchmark(M, N, K, dtype)
+def test_gemm(M, N, K, dtype, trans_A=False, trans_B=False, tune=False):
+    op = Gemm(M, N, K, trans_A=trans_A, trans_B=trans_B, dtype=dtype, tune=tune)
+    benchmark = gemm_benchmark(M, N, K, dtype, trans_A=trans_A, trans_B=trans_B)
 
     inputs = benchmark.gen_inputs()
     benchmark.check(op, *inputs)
@@ -20,7 +20,9 @@ def test_gemm(M, N, K, dtype, tune=False):
     parser.add_argument('--K', type=int, default=1024, help='K')
     parser.add_argument(
        '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type')
+    parser.add_argument('--trans_A', action='store_true', default=False, help='transpose input A')
+    parser.add_argument('--trans_B', action='store_true', default=False, help='transpose input B')
     parser.add_argument('--tune', action='store_true', default=False, help='enable autotune')
     args = parser.parse_args()
 
-    test_gemm(args.M, args.N, args.K, str2dtype[args.dtype], args.tune)
\ No newline at end of file
+    test_gemm(args.M, args.N, args.K, str2dtype[args.dtype], args.trans_A, args.trans_B, args.tune)
\ No newline at end of file
diff --git a/top/functions/__init__.py b/top/functions/__init__.py
index 5559644..cc01838 100644
--- a/top/functions/__init__.py
+++ b/top/functions/__init__.py
@@ -3,4 +3,5 @@
 from .mha_decode import *
 from .gqa_decode import *
 from .mla_decode import *
-from .sparse_mla import *
\ No newline at end of file
+from .sparse_mla import *
+from .matmul import *
\ No newline at end of file
diff --git a/top/functions/matmul.py b/top/functions/matmul.py
new file mode 100644
index 0000000..deadc03
--- /dev/null
+++ b/top/functions/matmul.py
@@ -0,0 +1,46 @@
+import torch
+from .function import Function
+from top.ops.gemm import Gemm
+
+__all__ = ['matmul']
+
+
+class gemm_ctx(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, A, B, fwd_op, da_bwd_op, db_bwd_op):
+        O = fwd_op(A, B)
+
+        ctx.save_for_backward(A, B)
+        ctx.da_bwd_op = da_bwd_op
+        ctx.db_bwd_op = db_bwd_op
+
+        return O
+
+    @staticmethod
+    def backward(ctx, dO):
+        A, B = ctx.saved_tensors
+
+        dO = dO.contiguous()
+        dA = ctx.da_bwd_op(dO, B)
+        dB = ctx.db_bwd_op(A, dO)
+
+        return dA, dB, None, None, None
+
+
+class matmul(Function):
+
+    def __init__(
+        self,
+        M: int,
+        N: int,
+        K: int,
+        dtype=torch.float16,
+        tune=False,
+    ):
+        self.fwd_op = Gemm(M, N, K, dtype=dtype, tune=tune)
+        self.da_bwd_op = Gemm(M, K, N, dtype=dtype, trans_B=True, tune=tune)
+        self.db_bwd_op = Gemm(K, N, M, dtype=dtype, trans_A=True, tune=tune)
+
+    def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
+        return gemm_ctx.apply(A, B, self.fwd_op, self.da_bwd_op, self.db_bwd_op)
diff --git a/top/kernels/gemm/gemm.py b/top/kernels/gemm/gemm.py
index 0ef1dd4..1eea5e4 100644
--- a/top/kernels/gemm/gemm.py
+++ b/top/kernels/gemm/gemm.py
@@ -7,22 +7,27 @@
 from top.kernels import Kernel
 
 
-def _gemm_kernel(M, N, K, dtype='float16'):
+def _gemm_kernel(M, N, K, trans_A, trans_B, dtype='float16'):
 
     accum_dtype = "float"
 
     @tilelang.jit(out_idx=[-1], compile_flags=["-O3", "-DENABLE_BF16"])
     def _gemm_func(block_M, block_N, block_K, threads, num_stages, enable_rasteration):
 
+        A_shape = (K, M) if trans_A else (M, K)
+        B_shape = (N, K) if trans_B else (K, N)
+        A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
+        B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
+
         @T.prim_func
         def _gemm_main(
-                A: T.Tensor((M, K), dtype),  # type: ignore
-                B: T.Tensor((K, N), dtype),  # type: ignore
+                A: T.Tensor(A_shape, dtype),  # type: ignore
+                B: T.Tensor(B_shape, dtype),  # type: ignore
                 C: T.Tensor((M, N), dtype),  # type: ignore
         ):
             with T.Kernel(
                     T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-                A_shared = T.alloc_shared((block_M, block_K), dtype)
-                B_shared = T.alloc_shared((block_K, block_N), dtype)
+                A_shared = T.alloc_shared(A_shared_shape, dtype)
+                B_shared = T.alloc_shared(B_shared_shape, dtype)
                 C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
                 C_shared = T.alloc_shared((block_M, block_N), dtype)
@@ -34,9 +39,20 @@ def _gemm_main(
 
                 T.clear(C_local)
                 for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
-                    T.copy(A[by * block_M, k * block_K], A_shared)
-                    T.copy(B[k * block_K, bx * block_N], B_shared)
-                    T.gemm(A_shared, B_shared, C_local)
+                    if not trans_A:
+                        # A: (M, K)
+                        T.copy(A[by * block_M, k * block_K], A_shared)  # [block_M, block_K]
+                    else:
+                        # A: (K, M)
+                        T.copy(A[k * block_K, by * block_M], A_shared)  # [block_K, block_M]
+
+                    if not trans_B:
+                        # B: (K, N)
+                        T.copy(B[k * block_K, bx * block_N], B_shared)  # [block_K, block_N]
+                    else:
+                        # B: (N, K)
+                        T.copy(B[bx * block_N, k * block_K], B_shared)  # [block_N, block_K]
+                    T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
 
                 T.copy(C_local, C_shared)
                 T.copy(C_shared, C[by * block_M, bx * block_N])
@@ -49,14 +65,22 @@ def _gemm_main(
 class gemm_kernel(Kernel):
     supported_archs: list[int] = [80, 89, 90]
 
-    def __init__(self, M, N, K, dtype, config: Optional[dict] = None, tune=False):
+    def __init__(self,
+                 M,
+                 N,
+                 K,
+                 dtype,
+                 config: Optional[dict] = None,
+                 tune=False,
+                 trans_A=False,
+                 trans_B=False):
         super().__init__()
         self.M = M
         self.N = N
         self.K = K
         self.dtype = dtype
 
-        self.kernel = _gemm_kernel(M, N, K, self.dtype_str)
+        self.kernel = _gemm_kernel(M, N, K, trans_A, trans_B, self.dtype_str)
 
         self.init_config(config, tune)
 
diff --git a/top/layers/__init__.py b/top/layers/__init__.py
index a6832a5..7112d8b 100644
--- a/top/layers/__init__.py
+++ b/top/layers/__init__.py
@@ -1,3 +1,4 @@
 from .flash_attn import *
 from .flash_decode import *
-from .deepseek_mla import *
\ No newline at end of file
+from .deepseek_mla import *
+from .linear import *
\ No newline at end of file
diff --git a/top/layers/linear.py b/top/layers/linear.py
new file mode 100644
index 0000000..2ae510b
--- /dev/null
+++ b/top/layers/linear.py
@@ -0,0 +1,34 @@
+import math
+import torch
+from torch import nn
+from top.functions import matmul
+
+
+class Linear(nn.Module):
+
+    def __init__(
+        self,
+        batch_size: int,
+        out_features: int,
+        in_features: int,
+        device='cuda',
+        dtype=torch.float16,
+        tune=False,
+    ):
+        super().__init__()
+        factory_kwargs = {"device": device, "dtype": dtype}
+        self.weight = nn.Parameter(torch.empty((in_features, out_features), **factory_kwargs))
+        self.fn = matmul(
+            batch_size,
+            out_features,
+            in_features,
+            dtype=self.weight.dtype,
+            tune=tune,
+        )
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.fn(x, self.weight)
diff --git a/top/ops/gemm.py b/top/ops/gemm.py
index 2cc1e01..7ed2678 100644
--- a/top/ops/gemm.py
+++ b/top/ops/gemm.py
@@ -12,6 +12,8 @@ def __init__(self,
                  M: int,
                  N: int,
                  K: int,
+                 trans_A=False,
+                 trans_B=False,
                  dtype=torch.float16,
                  kernel_map: Optional[Dict[str, Kernel]] = None,
                  tune=False):
@@ -22,7 +24,8 @@ def __init__(self,
         self.dtype = dtype
 
         self.dispatch_kernel(kernel_map)
-        self.kernel = self.kernel_map["gemm_kernel"](M, N, K, self.dtype, tune=tune)
+        self.kernel = self.kernel_map["gemm_kernel"](
+            M, N, K, self.dtype, tune=tune, trans_A=trans_A, trans_B=trans_B)
 
     @property
     def default_kernel_map(self):
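
Note on the backward composition in top/functions/matmul.py (a minimal sketch appended for review, not part of the patch): for O = A @ B with A of shape (M, K) and B of shape (K, N), the chain rule gives dA = dO @ B^T and dB = A^T @ dO. That is why the dA op is Gemm(M, K, N, ...) consuming (dO, B) with trans_B=True, and the dB op is Gemm(K, N, M, ...) consuming (A, dO) with trans_A=True. The plain-PyTorch check below uses illustrative sizes only:

import torch

M, N, K = 4, 5, 3  # illustrative sizes
A = torch.randn(M, K, dtype=torch.float64, requires_grad=True)
B = torch.randn(K, N, dtype=torch.float64, requires_grad=True)

O = A @ B
dO = torch.randn_like(O)
O.backward(dO)  # autograd reference gradients

# What the two backward Gemm ops are meant to compute:
#   dA = dO @ B^T  (Gemm(M, K, N) with trans_B=True, inputs dO and B)
#   dB = A^T @ dO  (Gemm(K, N, M) with trans_A=True, inputs A and dO)
assert torch.allclose(A.grad, dO @ B.T)
assert torch.allclose(B.grad, A.T @ dO)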