Add cutlass python dsl executor for quack-kernels
#2719
base: main
Pull Request Overview
This PR adds support for the CUTLASS DSL executor (cutlass_dsl_ex) to Thunder, integrating the quack library for optimized operations like softmax, cross_entropy, layer_norm, and RMS norm on NVIDIA SM9.0/10.0 GPUs.
- Introduces a new cutlass_dsl_ex executor with quack operation implementations
- Adds comprehensive test coverage for quack operations
- Adds benchmark suites for performance comparison against nvfuser and torch_compile
- Registers the new executor in Thunder's executor registry
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| thunder/executors/cutlass_dsl_ex.py | New file implementing the cutlass_dsl executor with quack operations for softmax, cross_entropy, layer_norm, and RMS norm |
| thunder/extend/__init__.py | Registers cutlass_dsl_ex in the get_all_executors function |
| thunder/tests/test_extend.py | Updates test to include cutlass_dsl executor in the expected executors list |
| thunder/tests/test_cutlass_dsl_ex.py | New test file with comprehensive tests for quack operations |
| thunder/benchmarks/targets.py | Adds benchmark classes and test functions for quack operations |
    if requires_reshpae := a.ndim > 2:
        a = a.view(-1, original_shape[-1])
    ret = softmax_fwd(a)
    if requires_reshpae:
Copilot
AI
Nov 6, 2025
Corrected spelling of 'requires_reshpae' to 'requires_reshape'.
    - if requires_reshpae := a.ndim > 2:
    -     a = a.view(-1, original_shape[-1])
    - ret = softmax_fwd(a)
    - if requires_reshpae:
    + if requires_reshape := a.ndim > 2:
    +     a = a.view(-1, original_shape[-1])
    + ret = softmax_fwd(a)
    + if requires_reshape:
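The pattern in this hunk (collapse the leading dimensions, run a 2D-only kernel, restore the shape) can be sketched without a GPU. This is only an illustration under stated assumptions: `softmax_rows_2d` is a hypothetical stand-in for quack's `softmax_fwd`, and a tensor is modelled as a flat row-major list plus a shape tuple rather than a `torch.Tensor`.

```python
import math


def softmax_rows_2d(rows: list[list[float]]) -> list[list[float]]:
    """Hypothetical stand-in for a kernel that only accepts 2D input."""
    out = []
    for row in rows:
        row_max = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - row_max) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def softmax_last_dim(
    flat: list[float], shape: tuple[int, ...]
) -> tuple[list[float], tuple[int, ...]]:
    """Softmax over the last dimension of a row-major tensor of rank >= 1."""
    original_shape = shape
    last = original_shape[-1]
    # Counterpart of a.view(-1, original_shape[-1]): one row per leading index.
    rows = [flat[i : i + last] for i in range(0, len(flat), last)]
    result = softmax_rows_2d(rows)
    # Counterpart of viewing the result back to original_shape.
    return [v for row in result for v in row], original_shape


values, shape = softmax_last_dim([0.0] * 12, (2, 3, 2))
```

Because the data is already flat here, the rank check from the hunk is not needed; with real tensors the `a.ndim > 2` guard avoids an unnecessary view.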
    a.ndim != 2
    or a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
    and target.ndim == 1
    and target.dytpe in {dtypes.int32, dtypes.int64}
Copilot
AI
Nov 6, 2025
Corrected spelling of 'dytpe' to 'dtype'.
    - and target.dytpe in {dtypes.int32, dtypes.int64}
    + and target.dtype in {dtypes.int32, dtypes.int64}
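Separately from the typo, the quoted condition mixes `or` and `and` without parentheses. In Python, `and` binds tighter than `or`, so `A or B and C and D` is grouped as `A or (B and C and D)`. The hunk is truncated, so the intended logic is an assumption here; the sketch below only shows how the grouping behaves and one way to write the requirements explicitly. `FakeTensor` and the string dtypes are hypothetical stand-ins for Thunder's `TensorProxy` and `dtypes`.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a TensorProxy: only rank and dtype matter here."""

    ndim: int
    dtype: str


FLOAT_DTYPES = {"float16", "bfloat16", "float32"}
INDEX_DTYPES = {"int32", "int64"}


def as_quoted(a: FakeTensor, target: FakeTensor) -> bool:
    # Same operator layout as the quoted hunk (with the dtype typo fixed).
    # Python groups this as: A or (B and C and D).
    return (
        a.ndim != 2
        or a.dtype not in FLOAT_DTYPES
        and target.ndim == 1
        and target.dtype in INDEX_DTYPES
    )


def is_supported(a: FakeTensor, target: FakeTensor) -> bool:
    # Assumed intent: every requirement must hold for the kernel to be usable.
    return (
        a.ndim == 2
        and a.dtype in FLOAT_DTYPES
        and target.ndim == 1
        and target.dtype in INDEX_DTYPES
    )


a = FakeTensor(ndim=2, dtype="float32")
bad_target = FakeTensor(ndim=1, dtype="float32")  # non-integer target dtype
```

If the quoted expression is meant as a "reject this input" test, it evaluates to `False` for `bad_target`, so the unsupported target dtype is not rejected; the explicit version does reject it.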
    def quack_softmax_backward_meta(g: TensorProxy, a: TensorProxy) -> TensorProxy:
        return TensorProxy(like=g)


    quack_softmax_backward = cutlass_dsl_ex.register_operator(
Copilot
AI
Nov 6, 2025
The global variable 'quack_softmax_backward' is not used.
    - quack_softmax_backward = cutlass_dsl_ex.register_operator(
    + cutlass_dsl_ex.register_operator(
        return thunder.jit(fn, executors=[nvfuserex])


    class BaseBenchmarkForQuack(Benchmark, metaclass=UserFacingBenchmarkMeta):
Copilot
AI
Nov 6, 2025
This class does not call Benchmark.__init__ during initialization. (BaseBenchmarkForQuack.__init__ may be missing a call to a base class __init__)
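A minimal sketch of what this warning is about, assuming the base class sets up state in its initializer. The `Benchmark` class and its `devices` attribute below are hypothetical stand-ins; what Thunder's real `Benchmark.__init__` does is not shown in this PR page.

```python
class Benchmark:
    """Hypothetical stand-in for the benchmark base class."""

    def __init__(self) -> None:
        # State the base class expects to exist on every instance (assumed).
        self.devices: list[str] = []


class WithoutSuperInit(Benchmark):
    def __init__(self, shape: tuple[int, ...]) -> None:
        # The base initializer never runs, so `devices` is never created.
        self.shape = shape


class WithSuperInit(Benchmark):
    def __init__(self, shape: tuple[int, ...]) -> None:
        super().__init__()  # run the base initializer first
        self.shape = shape


broken = WithoutSuperInit((4, 8))
fixed = WithSuperInit((4, 8))
```

Here `broken` has no `devices` attribute while `fixed` does, which is the kind of latent failure the warning points at.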
        weight: TensorProxy | None = None,
        bias: TensorProxy | None = None,
        eps: Number = 1e-5,
    ) -> bool:
        if (
            a.dtype not in {dtypes.float16, dtypes.bfloat16, dtypes.float32}
            or weight.ndim != 1
            or a.shape[-1] != weight.shape[0]
            or weight.dtype not in {dtypes.float32}
Can weight be None? In that case this would need to check before trying to access .ndim
Good catch. I will check it.
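One possible shape for the guard discussed here, as a sketch only. It assumes the checker should decline (return `False`) when `weight` is `None`; whether the kernel can instead run without a weight is not stated in this thread. `FakeTensor` and the string dtypes are hypothetical stand-ins for `TensorProxy` and `dtypes`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a TensorProxy."""

    shape: tuple[int, ...]
    dtype: str

    @property
    def ndim(self) -> int:
        return len(self.shape)


def norm_checker(
    a: FakeTensor,
    weight: FakeTensor | None = None,
    bias: FakeTensor | None = None,
    eps: float = 1e-5,
) -> bool:
    # Check for None before touching weight.ndim / weight.shape / weight.dtype.
    if weight is None:
        return False  # assumption: decline and let another executor handle it
    if (
        a.dtype not in {"float16", "bfloat16", "float32"}
        or weight.ndim != 1
        or a.shape[-1] != weight.shape[0]
        or weight.dtype not in {"float32"}
    ):
        return False
    return True


x = FakeTensor(shape=(2, 8), dtype="bfloat16")
w = FakeTensor(shape=(8,), dtype="float32")
```

With the early return, `norm_checker(x)` returns `False` instead of raising `AttributeError` on `None.ndim`.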
thunder/executors/cutlass_dsl_ex.py
Outdated
    quack_version: LooseVersion
    try:
        import quack
Do we need to add this to the requirements so that it gets installed?
I don't think we should. pip install quack-kernels seems to install CUDA Python packages such as nvidia-cutlass-dsl, and I don't know how to have requirements.txt install CUDA Python packages in a way that respects users' local environments.
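Keeping quack out of the requirements means the executor has to treat it as an optional dependency. A minimal sketch of one way to detect that without importing the package, using only the standard library; the flag name `QUACK_AVAILABLE` and the helper `is_installed` are illustrative, not names from this PR.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(module_name) is not None


# `quack` is the import name used in the hunk above; the result depends
# entirely on the local environment.
QUACK_AVAILABLE = is_installed("quack")

if not QUACK_AVAILABLE:
    # An executor could report itself as unavailable here instead of failing.
    reason = "quack is not installed; try `pip install quack-kernels`"
else:
    reason = ""
```

The hunk itself uses a `try: import quack` block, which additionally allows reading the installed version; `find_spec` is just a lighter probe for presence.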
        expected = F.cross_entropy(ref_x, targets, reduction="none")
        actual = jitted(x, targets, reduction="none")
        torch.testing.assert_close(expected, actual)
It seems the backward pass is not tested.
I haven't managed to get backward working.
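For reference when a backward test is eventually added: for cross-entropy with reduction="none", the gradient of one row's loss with respect to its logits is `softmax(x) - one_hot(target)`. The pure-Python sketch below checks that identity against central finite differences. It does not call quack, Thunder, or PyTorch, and all function names are illustrative.

```python
import math


def softmax(row: list[float]) -> list[float]:
    row_max = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - row_max) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(row: list[float], target: int) -> float:
    """Loss for a single sample, i.e. one element of a reduction="none" result."""
    return -math.log(softmax(row)[target])


def analytic_grad(row: list[float], target: int) -> list[float]:
    probs = softmax(row)
    return [p - (1.0 if i == target else 0.0) for i, p in enumerate(probs)]


def numeric_grad(row: list[float], target: int, h: float = 1e-6) -> list[float]:
    grad = []
    for i in range(len(row)):
        plus = row[:i] + [row[i] + h] + row[i + 1 :]
        minus = row[:i] + [row[i] - h] + row[i + 1 :]
        # Central difference approximation of d(loss)/d(row[i]).
        grad.append((cross_entropy(plus, target) - cross_entropy(minus, target)) / (2 * h))
    return grad


logits = [0.5, -1.2, 2.0, 0.1]
target = 2
max_error = max(
    abs(a - n)
    for a, n in zip(analytic_grad(logits, target), numeric_grad(logits, target))
)
```

In an actual test the reference gradient would more likely come from autograd on `F.cross_entropy`, compared with the executor's result in the same way as the forward check above.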
Starting with quack's softmax Signed-off-by: Masaki Kozuki <[email protected]>
it seems that quack's cross-entropy function upcasts inputs to fp32, thus updating test and meta function Signed-off-by: Masaki Kozuki <[email protected]>
What does this PR do?
As per the title, this adds a CUTLASS Python DSL executor.
This PR registers the kernels defined in https://github.com/Dao-AILab/quack, except matmul. Backward is not integrated.