[CPU][float8] Add QEmbeddingbag kernel #2686
Open

shiyang-weng wants to merge 9 commits into pytorch:main from shiyang-weng:wengshiy/embeddingbag_krnl
Commits
a695557 add embeddingbag kernel (shiyang-weng)
025aa16 switch to use cvtfp8e4m3_fp32 (shiyang-weng)
ab62099 improve code style (shiyang-weng)
badb85d rm unused buf (shiyang-weng)
8069e4a mv ut to test/test_ops.py (shiyang-weng)
9d0f7a5 refine kernel (shiyang-weng)
ae07dc6 add test case (shiyang-weng)
72f5017 add more assert (shiyang-weng)
0e10992 add more test case (shiyang-weng)
New file (179 lines):
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec512/vec512_float8.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/EmbeddingBag.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Unroll.h>
#include <torch/all.h>

namespace torchao {

namespace {

#if defined(CPU_CAPABILITY_AVX512)
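// Load 16 packed e4m3fn values from x and widen them to 16 fp32 lanes.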
static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
  __m512 o;
  __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
  at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
  return o;
}
#endif

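// Sum the fp8 rows of `weight` selected by `indices` for each bag in
// [bs_begin, bs_end), scale the sum by `scale`, and write one emb_dim-wide
// row of floats per bag into `result` (rows are num_emb * emb_dim apart).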
template <typename index_t>
inline void qembeddingbag_kern(const int64_t bs_begin, const int64_t bs_end,
                               const int64_t num_emb, const int64_t emb_dim,
                               const index_t last_offset,
                               const index_t *indices, const index_t *offsets,
                               const at::Float8_e4m3fn *weight,
                               const double scale, float *result) {
#if defined(CPU_CAPABILITY_AVX512)
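  // Vectorized path: process each embedding row in blocks of 128 floats,
  // keeping 8 x 16-lane fp32 accumulators live per block.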
  if (emb_dim % 128 == 0) {
    constexpr int64_t block_dim = 128;
    const int64_t num_blocks = emb_dim / block_dim;
    __m512 scale_v = _mm512_set1_ps(scale);
    for (int64_t b = bs_begin; b < bs_end; ++b) {
      __m512 x0, x1, x2, x3, x4, x5, x6, x7;
      int64_t start_idx = offsets[b];
      int64_t end_idx = ((b + 1) == bs_end && last_offset != -1)
                            ? last_offset
                            : offsets[b + 1];
      for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
        // load the first row of the bag
        int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
        float *block_result = result + block_dim * block_id;
        x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
        x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
        x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
        x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
        x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
        x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
        x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
        x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
        // accumulate the remaining rows of the bag
        for (int64_t j = start_idx + 1; j < end_idx; ++j) {
          idx = indices[j] * emb_dim + block_dim * block_id;
          x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
          x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
          x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
          x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
          x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
          x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
          x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
          x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
        }
        // apply the scale and store the block
        x0 = _mm512_mul_ps(x0, scale_v);
        x1 = _mm512_mul_ps(x1, scale_v);
        x2 = _mm512_mul_ps(x2, scale_v);
        x3 = _mm512_mul_ps(x3, scale_v);
        x4 = _mm512_mul_ps(x4, scale_v);
        x5 = _mm512_mul_ps(x5, scale_v);
        x6 = _mm512_mul_ps(x6, scale_v);
        x7 = _mm512_mul_ps(x7, scale_v);
        _mm512_store_ps(block_result, x0);
        _mm512_store_ps(block_result + 16, x1);
        _mm512_store_ps(block_result + 32, x2);
        _mm512_store_ps(block_result + 48, x3);
        _mm512_store_ps(block_result + 64, x4);
        _mm512_store_ps(block_result + 80, x5);
        _mm512_store_ps(block_result + 96, x6);
        _mm512_store_ps(block_result + 112, x7);
      }
      result += num_emb * emb_dim;
    }
    return;
  }
#endif
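  // Scalar fallback: used when AVX-512 is unavailable or emb_dim is not a
  // multiple of 128.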
  for (int64_t b = bs_begin; b < bs_end; ++b) {
    int64_t start_idx = offsets[b];
    int64_t end_idx =
        ((b + 1) == bs_end && last_offset != -1) ? last_offset : offsets[b + 1];
    for (int64_t d = 0; d < emb_dim; d++) {
      int64_t idx = indices[start_idx] * emb_dim;
      float value = float(weight[idx + d]);
      for (int64_t j = start_idx + 1; j < end_idx; ++j) {
        idx = indices[j] * emb_dim;
        value += float(weight[idx + d]);
      }
      value = value * scale;
      result[d] = value;
    }
    result += num_emb * emb_dim;
  }
}

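// Parallel driver: splits the batch into blocks of 512 bags and runs the
// kernel on each block. w_scale is pre-divided by o_scale so the kernel
// applies a single multiplier.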
template <typename index_t, typename data_t>
void qembeddingbagcat(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
                      index_t *offsets_ptr, int64_t num_batch, int64_t emb_dim,
                      index_t last_offset, double w_scale, double o_scale) {
  constexpr int64_t b_block = 512;
  const int64_t n_b_blocks = (num_batch - 1) / b_block + 1;
  w_scale /= o_scale;
  const int64_t num_emb = 1;
#pragma omp parallel for collapse(2)
  for (int64_t b = 0; b < n_b_blocks; ++b) {
    for (int64_t n = 0; n < num_emb; ++n) {
      const int64_t bs_begin = b * b_block;
      const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
      float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
      // last_offset caps the final bag when offsets does not include the end
      // boundary; blocks that end before num_batch always have a valid
      // offsets[b + 1], so disable last_offset for them by passing -1.
      const index_t block_last_offset =
          (bs_end == num_batch) ? last_offset : static_cast<index_t>(-1);
      qembeddingbag_kern(bs_begin, bs_end, num_emb, emb_dim, block_last_offset,
                         indices_ptr, offsets_ptr, w_ptr, w_scale, r);
    }
  }
}

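// Entry point registered below as torchao::qembeddingbag (CPU): validates the
// inputs, then runs the blocked kernel over the quantized table.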
at::Tensor qembeddingbag_impl(const at::Tensor &qweight,
                              const at::Tensor &indices,
                              const at::Tensor &offsets,
                              const at::Tensor &w_scales, double o_scale,
                              const int64_t mode, bool include_last_offset) {
  // Only support include_last_offset == true and mode ==
  // at::native::EmbeddingBagMode::SUM.
  // TODO: Support more cases.
  TORCH_CHECK(include_last_offset,
              "qembeddingbag: only support include_last_offset");
  TORCH_CHECK(mode == at::native::EmbeddingBagMode::SUM,
              "qembeddingbag: only support sum mode");
  int64_t batch_size =
      include_last_offset ? offsets.size(0) - 1 : offsets.size(0);
  int64_t emb_dim = qweight.size(1);

  auto index_type = indices.scalar_type();
  auto qtype = qweight.scalar_type();
  float w_scale = w_scales.data_ptr<float>()[0];

  TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
              "qembeddingbag: only accept contiguous input");
  TORCH_CHECK(offsets.scalar_type() == index_type,
              "qembeddingbag: index and offset must be of the same type");
  TORCH_CHECK(qweight.is_contiguous(),
              "qembeddingbag: only accept contiguous weight");
  TORCH_CHECK(qweight.dim() == 2,
              "qembeddingbag: only accept weight with dim == 2");
  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
              "qembeddingbag: only support e4m3fn weight");
  // indices.numel() serves as the end boundary of the final bag
  int64_t last_offset = indices.numel();

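  // Allocate a float32 [batch_size, emb_dim] output and dispatch on the index
  // dtype (int32 or int64).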
  at::Tensor output =
      at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
    at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
    index_t *indices_ptr = indices.data_ptr<index_t>();
    index_t *offsets_ptr = offsets.data_ptr<index_t>();
    float *output_ptr = output.data_ptr<float>();
    qembeddingbagcat<index_t, at::Float8_e4m3fn>(
        output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
        last_offset, w_scale, o_scale);
  });
  return output;
}

} // anonymous namespace

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
  m.impl("torchao::qembeddingbag", &qembeddingbag_impl);
}

} // namespace torchao
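
For context, a minimal Python sketch of how the op could be exercised and checked against a float32 reference. This assumes the schema registered for torchao::qembeddingbag follows the argument order of qembeddingbag_impl above and that mode 0 means sum; the PR's actual tests live in test/test_ops.py.

import torch
import torch.nn.functional as F

num_rows, emb_dim = 64, 128          # emb_dim % 128 == 0 hits the AVX-512 path
w_scales = torch.tensor([0.05], dtype=torch.float32)   # single per-tensor scale
o_scale = 2.0

# Quantize a random table to float8_e4m3fn; the op consumes the raw fp8 codes.
weight = torch.randn(num_rows, emb_dim)
qweight = (weight / w_scales).to(torch.float8_e4m3fn)

indices = torch.randint(0, num_rows, (32,), dtype=torch.int64)
offsets = torch.tensor([0, 3, 7, 12, 16, 20, 25, 28, 32], dtype=torch.int64)

# Argument order assumed from qembeddingbag_impl; mode 0 == EmbeddingBagMode::SUM.
out = torch.ops.torchao.qembeddingbag(qweight, indices, offsets, w_scales,
                                      o_scale, 0, True)

# Reference: dequantize, run a float32 sum-mode embedding_bag, rescale by 1/o_scale.
ref = F.embedding_bag(indices, qweight.float() * w_scales, offsets,
                      mode="sum", include_last_offset=True) / o_scale
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)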
Conversations
Is this the same as https://github.com/pytorch/pytorch/blob/371eacb2ae4ecdabc52ea4634ed21558df2f3bab/aten/src/ATen/native/native_functions.yaml#L2368C1-L2369C1, with the only difference being that qweight is float8?
@jerryzh168 Thanks for reviewing. Yes, I think so, except that the implementation in this PR has limited functionality so far.
This operator is used for inference, so I did not add any gradient-related parameters, including scale_grad_by_freq, sparse, per_sample_weights, and padding_idx.
I think we should add this to PyTorch directly if that's the case. float8 is a native dtype in PyTorch, so it makes the most sense to add the functionality there; we can error out in the op if some argument combination is not supported or is invalid for float8.
Intel platforms have fp8 instructions. When we are ready, we hope to update this kernel to use them. As far as I know, the latest GCC is required. Would that be difficult to support in PyTorch?
I'm not sure; can you open an issue for this in pytorch/pytorch?