-
Notifications
You must be signed in to change notification settings - Fork 8
[Feat] Add matmul function and Linear layers #52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
RMLYC
wants to merge
9
commits into
tile-ai:refactor
Choose a base branch
from
RMLYC:lyc/add_matmul
base: refactor
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
3ca215e
add matmul and function
RMLYC ea101cf
fix bug in matmul benchmarks
RMLYC 77c480b
add linear layer
RMLYC 3da459b
fix bug in linear
RMLYC e8bbd3b
add trans
RMLYC bd40d07
fix bug in gemm transpose
RMLYC b5bdd4f
add benchmark
RMLYC 34be55d
fix code format
RMLYC 6e30fd5
fix bug in linear
RMLYC File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| import argparse | ||
| from top.functions import matmul | ||
| from top.utils import str2dtype | ||
| from benchmarks import matmul_benchmark | ||
|
|
||
|
|
||
| def test_matmul(M, N, K, dtype, tune=False): | ||
| fn = matmul(M, N, K, dtype, tune=tune) | ||
| benchmark = matmul_benchmark(M, N, K, dtype) | ||
|
|
||
| inputs = benchmark.gen_inputs() | ||
| benchmark.check_fn(fn, *inputs) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument('--M', type=int, default=1024, help='M') | ||
| parser.add_argument('--N', type=int, default=1024, help='N') | ||
| parser.add_argument('--K', type=int, default=1024, help='K') | ||
| parser.add_argument( | ||
| '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') | ||
| parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') | ||
| args = parser.parse_args() | ||
|
|
||
| test_matmul(args.M, args.N, args.K, str2dtype[args.dtype], args.tune) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| import argparse | ||
| import torch | ||
| from top.layers import Linear | ||
| from top.utils import str2dtype | ||
|
|
||
|
|
||
| def test_linear(M, N, K, dtype, tune=False): | ||
| linear_layer = Linear(M, N, K, dtype=dtype, tune=tune) | ||
| input = torch.randn(M, K, dtype=dtype, device='cuda', requires_grad=True) | ||
|
|
||
| output = linear_layer(input) | ||
|
|
||
| loss = output.sum() | ||
| loss.backward() | ||
|
|
||
| print("Output shape:", output.shape) | ||
| print("Gradient shape:", input.grad.shape) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument('--M', type=int, default=1024, help='M') | ||
| parser.add_argument('--N', type=int, default=1024, help='N') | ||
| parser.add_argument('--K', type=int, default=1024, help='K') | ||
| parser.add_argument( | ||
| '--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], help='data type') | ||
| parser.add_argument('--tune', action='store_true', default=False, help='enable autotune') | ||
| args = parser.parse_args() | ||
|
|
||
| test_linear(args.M, args.N, args.K, str2dtype[args.dtype], args.tune) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| import torch | ||
| from .function import Function | ||
| from top.ops.gemm import Gemm | ||
|
|
||
| __all__ = ['matmul'] | ||
|
|
||
|
|
||
| class gemm_ctx(torch.autograd.Function): | ||
|
|
||
| @staticmethod | ||
| def forward(ctx, A, B, fwd_op, da_bwd_op, db_bwd_op): | ||
| O = fwd_op(A, B) | ||
|
|
||
| ctx.save_for_backward(A, B) | ||
| ctx.da_bwd_op = da_bwd_op | ||
| ctx.db_bwd_op = db_bwd_op | ||
|
|
||
| return O | ||
|
|
||
| @staticmethod | ||
| def backward(ctx, dO): | ||
| A, B = ctx.saved_tensors | ||
|
|
||
| dO = dO.contiguous() | ||
| dA = ctx.da_bwd_op(dO, B) | ||
| dB = ctx.db_bwd_op(A, dO) | ||
|
|
||
| return dA, dB, None, None, None | ||
|
|
||
|
|
||
| class matmul(Function): | ||
|
|
||
| def __init__( | ||
| self, | ||
| M: int, | ||
| N: int, | ||
| K: int, | ||
| dtype=torch.float16, | ||
| tune=False, | ||
| ): | ||
| self.fwd_op = Gemm(M, N, K, dtype=dtype, tune=tune) | ||
| self.da_bwd_op = Gemm(M, K, N, dtype=dtype, trans_B=False, tune=tune) | ||
| self.db_bwd_op = Gemm(K, N, M, dtype=dtype, trans_A=False, tune=tune) | ||
|
|
||
| def forward(self, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: | ||
| return gemm_ctx.apply(A, B, self.fwd_op, self.da_bwd_op, self.db_bwd_op) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| from .flash_attn import * | ||
| from .flash_decode import * | ||
| from .deepseek_mla import * | ||
| from .deepseek_mla import * | ||
| from .linear import * |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| import math | ||
| import torch | ||
| from torch import nn | ||
| from top.functions import matmul | ||
|
|
||
|
|
||
| class Linear(nn.Module): | ||
|
|
||
| def __init__( | ||
| self, | ||
| batch_size: int, | ||
| out_features: int, | ||
| in_features: int, | ||
| device='cuda', | ||
| dtype=torch.float16, | ||
| tune=False, | ||
| ): | ||
| super().__init__() | ||
| factory_kwargs = {"device": device, "dtype": dtype} | ||
| self.weight = nn.Parameter(torch.empty((in_features, out_features), **factory_kwargs)) | ||
| self.fn = matmul( | ||
| batch_size, | ||
| out_features, | ||
| in_features, | ||
| dtype=self.weight.dtype, | ||
| tune=tune, | ||
| ) | ||
| self.reset_parameters() | ||
|
|
||
| def reset_parameters(self): | ||
| nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| return self.fn(x, self.weight) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The input tensors
AandBare passed directly tofwd_op, which likely calls a custom CUDA kernel. These kernels often require input tensors to be contiguous in memory to function correctly. To prevent potential errors or incorrect computations with non-contiguous inputs, you should explicitly make them contiguous before the forward operation.