
Conversation

@RMLYC (Collaborator) commented Nov 18, 2025

  1. Add matmul functions
  2. Add Linear layers
  3. Bias is not yet supported

@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @RMLYC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers get up to speed quickly!

This pull request significantly enhances the library's foundational capabilities by introducing a custom matrix multiplication (matmul) function and building a Linear neural network layer upon it. The primary goal is to provide essential components for deep learning models, ensuring proper autograd support for gradient computation. The changes expand the library's core numerical operations and layer definitions, with an initial focus on functionality without bias.

Highlights

  • Custom Matmul Function: A new matmul function has been implemented, complete with a torch.autograd.Function context (gemm_ctx) to manage both forward and backward passes. This function leverages top.ops.gemm.Gemm for its core matrix multiplication operations. (A minimal sketch of this pattern follows this list.)
  • Linear Layer Implementation: A new Linear neural network layer has been added, which internally utilizes the newly introduced matmul function. This initial version of the Linear layer does not include support for bias terms.
  • Benchmarking and Testing: Dedicated benchmarking and testing infrastructure has been added. This includes a matmul_benchmark class for performance evaluation and new test files (test_matmul.py, test_linear.py) to validate the functionality of both the matmul function and the Linear layer.
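
For reviewers unfamiliar with the pattern, here is a minimal sketch of a matmul wrapped in torch.autograd.Function. It is illustrative only: torch.matmul stands in for the top.ops.gemm.Gemm kernels, whose exact signatures are not shown in this summary.

import torch

class MatmulFn(torch.autograd.Function):
    # Sketch only; torch.matmul stands in for the PR's Gemm kernels.

    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        return torch.matmul(A, B)

    @staticmethod
    def backward(ctx, grad_out):
        A, B = ctx.saved_tensors
        # For O = A @ B: dL/dA = dL/dO @ B^T and dL/dB = A^T @ dL/dO
        return torch.matmul(grad_out, B.t()), torch.matmul(A.t(), grad_out)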

@RMLYC requested a review from Rachmanino on Nov 18, 2025 at 11:40
@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a matmul function and a Linear layer. The overall implementation is a good start, but there are several areas for improvement regarding correctness, clarity, and testing. My review focuses on ensuring input tensors are handled correctly in the custom matmul function, improving the API design of the new Linear layer for better readability, and strengthening the tests to ensure numerical correctness. Addressing these points will make the new components more robust and easier to use.

Comment on lines +7 to +17
def test_linear(M, N, K, dtype, tune=False):
    linear_layer = Linear(M, N, K, dtype=dtype, tune=tune)
    input = torch.randn(M, K, dtype=dtype, device='cuda', requires_grad=True)

    output = linear_layer(input)

    loss = output.sum()
    loss.backward()

    print("Output shape:", output.shape)
    print("Gradient shape:", input.grad.shape)

Severity: high

This test only verifies that the Linear layer runs without errors and prints output shapes. It does not check for the numerical correctness of the output or the gradients. To ensure the layer is implemented correctly, you should compare its results against a reference implementation. I recommend creating a linear_benchmark class, similar to matmul_benchmark, and using benchmark.check_fn to validate the output and gradients against a torch reference (e.g., torch.matmul).
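
As a concrete illustration, a correctness check might look like the sketch below. The function name and tolerances are hypothetical, and the repo's benchmark.check_fn could replace the manual assertions:

import torch

def test_linear_correctness(M, N, K, dtype=torch.float16):
    # Hypothetical sketch: compare the custom layer against a plain
    # torch.matmul reference for both the output and the input gradient.
    layer = Linear(M, N, K, dtype=dtype)
    x = torch.randn(M, K, dtype=dtype, device='cuda', requires_grad=True)
    x_ref = x.detach().clone().requires_grad_(True)

    out = layer(x)
    ref = torch.matmul(x_ref, layer.weight.detach())  # no bias in this layer

    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)

    out.sum().backward()
    ref.sum().backward()
    torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-2, atol=1e-2)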


@staticmethod
def forward(ctx, A, B, fwd_op, da_bwd_op, db_bwd_op):
    O = fwd_op(A, B)

Severity: high

The input tensors A and B are passed directly to fwd_op, which likely calls a custom CUDA kernel. These kernels often require input tensors to be contiguous in memory to function correctly. To prevent potential errors or incorrect computations with non-contiguous inputs, you should explicitly make them contiguous before the forward operation.

Suggested change
-    O = fwd_op(A, B)
+    A = A.contiguous()
+    B = B.contiguous()
+    O = fwd_op(A, B)
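
For context, non-contiguous inputs arise easily in practice, for example from a transpose, and a kernel that assumes row-major layout would then read the wrong elements. A quick illustration:

import torch

A = torch.randn(64, 128, device='cuda')
B = A.t()                  # a view with swapped strides; no data is copied
print(B.is_contiguous())   # False: strides no longer match row-major layout
C = B.contiguous()         # materializes a row-major copy
print(C.is_contiguous())   # True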

Comment on lines +9 to +30
def __init__(
    self,
    M: int,
    N: int,
    K: int,
    device='cuda',
    dtype=torch.float16,
    tune=False,
):
    super().__init__()
    factory_kwargs = {"device": device, "dtype": dtype}
    self.weight = nn.Parameter(
        torch.empty((K, N), **factory_kwargs)
    )
    self.fn = matmul(
        M,
        N,
        K,
        dtype=self.weight.dtype,
        tune=tune,
    )
    self.reset_parameters()

Severity: medium

The parameter names M, N, K in the __init__ method are ambiguous and don't follow common PyTorch conventions. Based on the implementation, M is the batch size, N is the number of output features, and K is the number of input features. Renaming them to batch_size, out_features, and in_features respectively would significantly improve code readability and maintainability.

    def __init__(
            self,
            batch_size: int,
            out_features: int,
            in_features: int,
            device='cuda',
            dtype=torch.float16,
            tune=False,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.weight = nn.Parameter(
            torch.empty((in_features, out_features), **factory_kwargs)
        )
        self.fn = matmul(
            batch_size,
            out_features,
            in_features,
            dtype=self.weight.dtype,
            tune=tune,
        )
        self.reset_parameters()
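
If the renaming is adopted, call sites become self-documenting, for example (values are hypothetical):

layer = Linear(batch_size=1024, out_features=4096, in_features=1024,
               dtype=torch.float16)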

Comment on lines +35 to +36
def forward(self, input: torch.Tensor) -> torch.Tensor:
    return self.fn(input, self.weight)

Severity: medium

The parameter name input shadows the Python built-in function input(). It's a best practice to avoid using names of built-ins as variable names to prevent confusion and potential bugs. Consider renaming it to x or input_tensor.

Suggested change
-def forward(self, input: torch.Tensor) -> torch.Tensor:
-    return self.fn(input, self.weight)
+def forward(self, x: torch.Tensor) -> torch.Tensor:
+    return self.fn(x, self.weight)
