Add F.scaled_mm
#2720
base: main
Conversation
Pull Request Overview
This PR adds support for the torch.nn.functional.scaled_mm operation in Thunder. This operation performs scaled matrix multiplication, which is commonly used for FP8 quantized operations. A usage sketch follows the change list below.
- Adds a new `scaled_mm` function to `thunder/torch/__init__.py` with input validation and shape inference logic
- Registers the implementation in `thunder/executors/torchex.py` to delegate to PyTorch
- Adds comprehensive test coverage with tensor-wise, row-wise, and block-wise scaling tests
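For context, here is a minimal sketch of a tensor-wise scaled call. The `ScalingType` recipe enum and the keyword names are assumptions based on the signature in the diff below and the linked PyTorch docs, and may differ across PyTorch builds:

```python
import torch
import torch.nn.functional as F

# Illustrative inputs; FP8 matmul requires a CUDA device with FP8 support.
a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
# Scaled matmuls typically want the second operand column-major, hence the .t().
b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")  # one scale for the whole tensor
scale_b = torch.tensor(1.0, device="cuda")

# F.ScalingType.TensorWise is taken from the linked docs and may be named
# differently (or be absent) in older PyTorch builds.
out = F.scaled_mm(
    a,
    b,
    scale_a,
    F.ScalingType.TensorWise,
    scale_b,
    F.ScalingType.TensorWise,
    output_dtype=torch.bfloat16,
)
```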
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `thunder/torch/__init__.py` | Implements the `scaled_mm` symbol with parameter validation and output shape/dtype inference |
| `thunder/executors/torchex.py` | Registers the torch executor implementation for `scaled_mm` |
| `thunder/tests/test_ops.py` | Adds test helper functions and comprehensive tests for `scaled_mm` with different scaling strategies |
```python
    ValueError,
)
for enum_value in values:
    _ = int(enum_value.value) if hasattr(enum_value, "value") else int(enum_value)
```
Should we raise a more detailed error here instead of a `TypeError`?
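If a friendlier message is the goal, one hypothetical shape for the check (`_check_scale_recipe` and its message are illustrative, not the PR's actual helper):

```python
def _check_scale_recipe(enum_value, allowed, param_name: str) -> None:
    # Name the offending parameter and the accepted values instead of
    # letting int() surface a bare TypeError.
    if enum_value not in allowed:
        raise ValueError(
            f"Expected {param_name} to be one of {[v.name for v in allowed]}, "
            f"but got {enum_value!r}"
        )
```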
`thunder/torch/__init__.py` (outdated)

```python
    scale_recipe_a,
    scale_b,
    scale_recipe_b,
    swizzle_a=None,
```
It would be helpful to add type annotations for `scale_a`, `scale_recipe_a`, `scale_b`, `scale_recipe_b`, `swizzle_a`, and `swizzle_b`.
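A hypothetical annotated signature, inferred from the diff; the concrete types the PR settles on may differ:

```python
from enum import Enum

import torch

# TensorLike is assumed to be Thunder's tensor-proxy alias from
# thunder/torch/__init__.py; the recipe/swizzle types are guesses.
def scaled_mm(
    a: TensorLike,
    b: TensorLike,
    scale_a: TensorLike,
    scale_recipe_a: int | Enum,
    scale_b: TensorLike,
    scale_recipe_b: int | Enum,
    swizzle_a: int | Enum | None = None,
    swizzle_b: int | Enum | None = None,
    bias: TensorLike | None = None,
    output_dtype: torch.dtype | None = None,
) -> TensorLike:
    ...
```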
```python
    tensor_args.append(bias)
utils.check_same_device(*tensor_args)

result_dtype = to_dtype(output_dtype or torch.bfloat16)
```
Is the `or torch.bfloat16` needed for the case where the user explicitly passes `output_dtype=None`?
I think it's still needed
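For what it's worth, a minimal sketch of the fallback semantics under discussion: `output_dtype or torch.bfloat16` covers both the omitted-argument case and an explicit `None`, because `None` is falsy while `torch.dtype` objects are truthy:

```python
import torch

def infer_result_dtype(output_dtype=None):
    # Falls back to bfloat16 whether the argument is omitted or
    # explicitly passed as None.
    return output_dtype or torch.bfloat16

assert infer_result_dtype() is torch.bfloat16              # omitted
assert infer_result_dtype(None) is torch.bfloat16          # explicit None
assert infer_result_dtype(torch.float16) is torch.float16  # user-specified
```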
`thunder/tests/test_ops.py` (outdated)

```python
if not hasattr(torch.nn.functional, "scaled_mm"):
    pytest.skip("torch.nn.functional.scaled_mm is not available in this PyTorch build")
device = torch.device("cuda")
torch.manual_seed(0)
```
It would be better not to set the seed so that we test with different values.
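One illustrative pattern (not what the PR does) that keeps inputs varied across runs while still letting a failure be replayed:

```python
import random

import torch

def reseed_for_test() -> int:
    # Draw a fresh seed per run so each run exercises different values,
    # but record it so a failing run can be reproduced exactly.
    seed = random.randrange(2**31)
    torch.manual_seed(seed)
    print(f"reproduce with torch.manual_seed({seed})")  # shown by pytest on failure
    return seed
```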
kshitij12345 left a comment:
Overall looks good, just a few comments regarding the tests. Thanks!
kshitij12345 left a comment:
LGTM, thanks @crcrpar
What does this PR do?
As per the title, this PR adds `F.scaled_mm` to `thunder.torch` and covers it with a torchex implementation.

Ref: https://docs.pytorch.org/docs/main/generated/torch.nn.functional.scaled_mm.html
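For reference, a hedged sketch of how the new symbol might be exercised end to end: `thunder.jit` traces the call through `thunder.torch.scaled_mm`, and the torchex executor delegates back to PyTorch. The `ScalingType` enum name is assumed from the linked docs:

```python
import torch
import torch.nn.functional as F

import thunder

def fn(a, b, scale_a, scale_b):
    # Tensor-wise scaling; enum name assumed from the linked docs.
    return F.scaled_mm(
        a,
        b,
        scale_a,
        F.ScalingType.TensorWise,
        scale_b,
        F.ScalingType.TensorWise,
        output_dtype=torch.bfloat16,
    )

jfn = thunder.jit(fn)
# out = jfn(a, b, scale_a, scale_b)  # inputs as in the tensor-wise sketch above
```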