[Feat] Add matmul function and Linear layers #52
base: refactor
Conversation
RMLYC commented on Nov 18, 2025
- Add matmul functions
- Add Linear layers
- Bias is not supported
Summary of Changes

Hello @RMLYC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the library's foundational capabilities by introducing a custom matrix multiplication (matmul) function and a Linear layer that uses it.
Code Review
This pull request introduces a matmul function and a Linear layer. The overall implementation is a good start, but there are several areas for improvement regarding correctness, clarity, and testing. My review focuses on ensuring input tensors are handled correctly in the custom matmul function, improving the API design of the new Linear layer for better readability, and strengthening the tests to ensure numerical correctness. Addressing these points will make the new components more robust and easier to use.
def test_linear(M, N, K, dtype, tune=False):
    linear_layer = Linear(M, N, K, dtype=dtype, tune=tune)
    input = torch.randn(M, K, dtype=dtype, device='cuda', requires_grad=True)

    output = linear_layer(input)

    loss = output.sum()
    loss.backward()

    print("Output shape:", output.shape)
    print("Gradient shape:", input.grad.shape)
This test only verifies that the Linear layer runs without errors and prints output shapes. It does not check for the numerical correctness of the output or the gradients. To ensure the layer is implemented correctly, you should compare its results against a reference implementation. I recommend creating a linear_benchmark class, similar to matmul_benchmark, and using benchmark.check_fn to validate the output and gradients against a torch reference (e.g., torch.matmul).
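As a lighter-weight illustration of that check, a sketch along these lines compares both the output and the input gradient against a plain torch.matmul reference. It bypasses the repository's benchmark utilities, and the function name, fp16 tolerances, and the assumption that Linear is importable here are all hypothetical rather than part of this PR.

import torch

def test_linear_correctness(M, N, K, dtype=torch.float16, tune=False):
    # Hypothetical correctness test: custom Linear vs. torch.matmul reference.
    linear_layer = Linear(M, N, K, dtype=dtype, tune=tune)
    x = torch.randn(M, K, dtype=dtype, device='cuda', requires_grad=True)
    x_ref = x.detach().clone().requires_grad_(True)

    out = linear_layer(x)
    out_ref = torch.matmul(x_ref, linear_layer.weight.detach())

    # Loose tolerances: fp16 accumulation order differs between kernels.
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)

    out.sum().backward()
    out_ref.sum().backward()
    torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-2, atol=1e-2)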
@staticmethod
def forward(ctx, A, B, fwd_op, da_bwd_op, db_bwd_op):
    O = fwd_op(A, B)
The input tensors A and B are passed directly to fwd_op, which likely calls a custom CUDA kernel. These kernels often require input tensors to be contiguous in memory to function correctly. To prevent potential errors or incorrect computations with non-contiguous inputs, you should explicitly make them contiguous before the forward operation.
Suggested change:
-    O = fwd_op(A, B)
+    A = A.contiguous()
+    B = B.contiguous()
+    O = fwd_op(A, B)
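For illustration only, here is how that guard might sit inside the surrounding autograd.Function. Since .contiguous() returns the same tensor when the input is already contiguous, the extra calls are cheap in the common case; the class name _MatmulFn and the ctx bookkeeping below are assumptions about the surrounding code, not taken from this PR.

import torch

class _MatmulFn(torch.autograd.Function):
    # Sketch of the suggested fix in context; backward details are omitted.
    @staticmethod
    def forward(ctx, A, B, fwd_op, da_bwd_op, db_bwd_op):
        # .contiguous() is a no-op for tensors that are already contiguous,
        # so the copy only happens when the layout actually needs fixing.
        A = A.contiguous()
        B = B.contiguous()
        ctx.save_for_backward(A, B)
        ctx.bwd_ops = (da_bwd_op, db_bwd_op)
        return fwd_op(A, B)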
def __init__(
    self,
    M: int,
    N: int,
    K: int,
    device='cuda',
    dtype=torch.float16,
    tune=False,
):
    super().__init__()
    factory_kwargs = {"device": device, "dtype": dtype}
    self.weight = nn.Parameter(
        torch.empty((K, N), **factory_kwargs)
    )
    self.fn = matmul(
        M,
        N,
        K,
        dtype=self.weight.dtype,
        tune=tune,
    )
    self.reset_parameters()
The parameter names M, N, K in the __init__ method are ambiguous and don't follow common PyTorch conventions. Based on the implementation, M is the batch size, N is the number of output features, and K is the number of input features. Renaming them to batch_size, out_features, and in_features respectively would significantly improve code readability and maintainability.
def __init__(
    self,
    batch_size: int,
    out_features: int,
    in_features: int,
    device='cuda',
    dtype=torch.float16,
    tune=False,
):
    super().__init__()
    factory_kwargs = {"device": device, "dtype": dtype}
    self.weight = nn.Parameter(
        torch.empty((in_features, out_features), **factory_kwargs)
    )
    self.fn = matmul(
        batch_size,
        out_features,
        in_features,
        dtype=self.weight.dtype,
        tune=tune,
    )
    self.reset_parameters()

def forward(self, input: torch.Tensor) -> torch.Tensor:
    return self.fn(input, self.weight)
The parameter name input shadows the Python built-in function input(). It's a best practice to avoid using names of built-ins as variable names to prevent confusion and potential bugs. Consider renaming it to x or input_tensor.
Suggested change:
-def forward(self, input: torch.Tensor) -> torch.Tensor:
-    return self.fn(input, self.weight)
+def forward(self, x: torch.Tensor) -> torch.Tensor:
+    return self.fn(x, self.weight)
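For context, a call site under the suggested renames might look like the sketch below; the keyword names assume both suggestions above are adopted, and the shapes are illustrative only.

import torch

# Hypothetical usage of the renamed Linear signature.
layer = Linear(batch_size=32, out_features=128, in_features=64,
               dtype=torch.float16, tune=False)
x = torch.randn(32, 64, dtype=torch.float16, device='cuda')
y = layer(x)  # expected shape: (32, 128)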
Force-pushed from 7d89d3c to 3da459b.