
Conversation

@crcrpar (Collaborator) commented Nov 6, 2025

What does this PR do?

As per the title, this PR adds F.scaled_mm to thunder.torch and covers it with a torchex implementation.

Ref: https://docs.pytorch.org/docs/main/generated/torch.nn.functional.scaled_mm.html
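For orientation, here is a minimal sketch of how the new symbol could be exercised end to end. The parameter names (scale_a, scale_recipe_a, swizzle_a, output_dtype, ...) are taken from the snippets in this PR; the linked docs are the authoritative reference for the signature, and the recipe values are placeholders:

```python
import torch
import torch.nn.functional as F
import thunder

# Sketch only: parameter names and their order follow the snippets in this
# PR; the recipe arguments come from PyTorch's scaling-recipe enum, for
# which the linked docs are the authoritative reference.
def fn(a, b, scale_a, recipe_a, scale_b, recipe_b):
    return F.scaled_mm(a, b, scale_a, recipe_a, scale_b, recipe_b,
                       output_dtype=torch.bfloat16)

# With this PR, thunder.jit can trace fn and the torchex executor
# delegates the call back to torch.nn.functional.scaled_mm.
jfn = thunder.jit(fn)
```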

Copilot AI (Contributor) left a comment:

Pull Request Overview

This PR adds support for the torch.nn.functional.scaled_mm operation in Thunder. This operation performs scaled matrix multiplication, which is commonly used for FP8 quantized operations.

  • Adds a new scaled_mm function to thunder/torch/__init__.py with input validation and shape inference logic
  • Registers the implementation in thunder/executors/torchex.py to delegate to PyTorch
  • Adds comprehensive test coverage with tensor-wise, row-wise, and block-wise scaling tests

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

  • thunder/torch/__init__.py: Implements the scaled_mm symbol with parameter validation and output shape/dtype inference
  • thunder/executors/torchex.py: Registers the torch executor implementation for scaled_mm
  • thunder/tests/test_ops.py: Adds test helper functions and comprehensive tests for scaled_mm with different scaling strategies


```python
    ValueError,
)
for enum_value in values:
    _ = int(enum_value.value) if hasattr(enum_value, "value") else int(enum_value)
```
Collaborator:

Should we raise a more detailed error here instead of a TypeError?
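For illustration, a more descriptive check could look like the sketch below; `utils.check` follows Thunder's check(pred, msg_fn, exception_type) convention, and the message wording here is hypothetical:

```python
# Sketch only: replace the bare int(...) cast (whose failure surfaces as a
# generic TypeError) with an explicit, descriptive validation.
for enum_value in values:
    raw = enum_value.value if hasattr(enum_value, "value") else enum_value
    utils.check(
        isinstance(raw, int),
        lambda: f"Expected an integer-valued scaling recipe, got {type(raw).__name__}",
        exception_type=ValueError,
    )
```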

```python
    scale_recipe_a,
    scale_b,
    scale_recipe_b,
    swizzle_a=None,
```
Collaborator:

It would be helpful to add type annotations for scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, and swizzle_b.
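For example, annotations along these lines might fit (a sketch: TensorLike is Thunder's tensor-proxy alias, and the int-based recipe/swizzle annotations are placeholders for whatever enum types torch.nn.functional.scaled_mm actually accepts):

```python
from collections.abc import Sequence

# Sketch only: the recipe/swizzle annotations below are placeholders and
# should mirror the enum types the PyTorch API accepts.
def scaled_mm(
    a: TensorLike,
    b: TensorLike,
    scale_a: TensorLike | Sequence[TensorLike],
    scale_recipe_a: int | Sequence[int],
    scale_b: TensorLike | Sequence[TensorLike],
    scale_recipe_b: int | Sequence[int],
    swizzle_a: int | None = None,
    swizzle_b: int | None = None,
    bias: TensorLike | None = None,
    output_dtype: torch.dtype | None = None,
) -> TensorLike:
    ...
```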

```python
    tensor_args.append(bias)
utils.check_same_device(*tensor_args)

result_dtype = to_dtype(output_dtype or torch.bfloat16)
```
Collaborator:

Is the `or torch.bfloat16` needed for the case where the user explicitly passes output_dtype=None?

@crcrpar (Collaborator, author):

I think it's still needed.
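Concretely, `output_dtype or torch.bfloat16` treats an omitted argument and an explicit output_dtype=None identically, since both are falsy:

```python
# Both an omitted output_dtype (defaulting to None) and an explicit None
# fall through to bfloat16; any real dtype passes through unchanged.
to_dtype(None or torch.bfloat16)           # resolves to bfloat16
to_dtype(torch.float32 or torch.bfloat16)  # resolves to float32
```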

```python
if not hasattr(torch.nn.functional, "scaled_mm"):
    pytest.skip("torch.nn.functional.scaled_mm is not available in this PyTorch build")
device = torch.device("cuda")
torch.manual_seed(0)
```
Collaborator:

It would be better not to set the seed so that we test with different values on each run.
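For instance, the fixed seed could be dropped in favor of fresh random inputs per run; torch.testing.make_tensor is one option (shapes and dtype here are illustrative):

```python
import torch
from torch.testing import make_tensor

# Illustrative shapes/dtype only: generate fresh random inputs on every run
# rather than pinning the global RNG with torch.manual_seed(0).
a = make_tensor((16, 32), dtype=torch.float32, device="cuda")
b = make_tensor((32, 8), dtype=torch.float32, device="cuda")
```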

@Lightning-AI deleted a comment from kshitij12345 Nov 7, 2025
Signed-off-by: Masaki Kozuki <[email protected]>
@kshitij12345 (Collaborator) left a comment:

Overall looks good, just a few comments regarding the tests. Thanks!

Signed-off-by: Masaki Kozuki <[email protected]>
@kshitij12345 (Collaborator) left a comment:

LGTM, thanks @crcrpar
