[CPU][float8] Add QEmbeddingbag kernel #2686
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2686
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0e10992 with merge base 7dbc816. This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/quantization/test_quant_api.py (Outdated)
"CPU" not in torch._C._dispatch_dump("torchao::qembeddingbag"), | ||
reason="cpp kernels not built", | ||
) | ||
def test_embeddingbag_cpu(self): |
the test should be added here I think: https://github.com/pytorch/ao/blob/main/test/test_ops.py
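For illustration, here is a minimal sketch of what a relocated test in test/test_ops.py might look like. The class name, shapes, scale value, and expected output shape are assumptions for this sketch, and passing mode=0 assumes ATen's integer encoding for "sum".

```python
import unittest

import torch
from torch.testing._internal.common_utils import TestCase, run_tests


class TestQEmbeddingBag(TestCase):
    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::qembeddingbag"),
        reason="cpp kernels not built",
    )
    def test_embeddingbag_cpu(self):
        # Illustrative float8 table: 10 embeddings of dimension 8.
        qweight = torch.randn(10, 8).to(torch.float8_e4m3fn)
        weight_scale = torch.tensor([0.1])
        indices = torch.tensor([0, 1, 2, 3])
        # With include_last_offset=True the final offset closes the last bag,
        # so these offsets describe two bags of two indices each.
        offsets = torch.tensor([0, 2, 4])
        out = torch.ops.torchao.qembeddingbag(
            qweight, indices, offsets, weight_scale, 1.0, 0, True
        )
        self.assertEqual(out.shape, torch.Size([2, 8]))


if __name__ == "__main__":
    run_tests()
```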
@pytorchbot label "topic: new feature"
LGTM. Have you run any benchmarks to make sure it's not too slow?
@jerryzh168 Could you help review this PR?
@@ -70,6 +70,9 @@
lib.define(
    "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
)
lib.define(
    "qembeddingbag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
)
Is this the same as https://github.com/pytorch/pytorch/blob/371eacb2ae4ecdabc52ea4634ed21558df2f3bab/aten/src/ATen/native/native_functions.yaml#L2368C1-L2369C1, with the only difference being that qweight is float8?
@jerryzh168 Thanks for reviewing. Yes, I think so, except that the implementation in this PR has limited functionality so far.
This operator is used for inference, so I did not add any gradient-related parameters, including scale_grad_by_freq, sparse, per_sample_weights, and padding_idx.
I think we should add this to PyTorch directly if that's the case. float8 is a native dtype in PyTorch, so it makes the most sense to just add the functionality there; we can error out in the op if some argument combination is not supported or invalid for float8.
Intel platforms have FP8 instructions, and when we are ready we hope to update this kernel to use them. As far as I know, the latest GCC is required. Would that be difficult to support in PyTorch?
I'm not sure; can you open an issue for this in pytorch/pytorch?
Implemented FP8 QEmbeddingBag on CPU, currently supporting:
- include_last_offset=True
- mode="sum"
Next steps