
[CPU][float8] Add QEmbeddingbag kernel #2686

Open · wants to merge 9 commits into main
64 changes: 64 additions & 0 deletions test/test_ops.py
@@ -780,5 +780,69 @@ def test_swizzle_mm():
)


EMBEDDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
EMBEDDINGBAG_BAG_SIZES = [1, 2, 128]
EMBEDDINGBAG_VECTOR_SIZES = [1, 128, 512]
EMBEDDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]

EMBEDDINGBAG_TEST_PARAMS = list(
    itertools.product(
        EMBEDDINGBAG_MULTIHOT_SIZES,
        EMBEDDINGBAG_BAG_SIZES,
        EMBEDDINGBAG_VECTOR_SIZES,
        EMBEDDINGBAG_INDEX_DTYPES,
    )
)


@pytest.mark.skipif(
"CPU" not in torch._C._dispatch_dump("torchao::qembeddingbag"),
reason="cpp kernels not built",
)
@pytest.mark.parametrize(
"multi_hot, batch_size, vector_size, index_type",
    EMBEDDINGBAG_TEST_PARAMS,
ids=str,
)
def test_embeddingbag_cpu(multi_hot, batch_size, vector_size, index_type):
qtype = torch.float8_e4m3fn
dtype = torch.float32
weight_scale = torch.tensor([2.0])
include_last_offset = True
mode = "sum"

if mode == "sum":
mode_enum = 0
elif mode == "mean":
mode_enum = 1
elif mode == "max":
mode_enum = 2
indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(index_type)

m = torch.nn.EmbeddingBag(
1000,
vector_size,
mode=mode,
dtype=dtype,
include_last_offset=include_last_offset,
)
fp8_weight = m.weight.data.to(qtype)
m.weight.data = fp8_weight.to(m.weight.dtype)

with torch.no_grad():
refe_out = m.forward(indices, offsets) * weight_scale
test_out = torch.ops.torchao.qembeddingbag(
fp8_weight,
indices,
offsets,
weight_scale,
1.0,
mode_enum,
include_last_offset,
).to(dtype)
torch.testing.assert_close(refe_out, test_out, atol=0, rtol=0)


if __name__ == "__main__":
pytest.main(sys.argv)
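For readers who want to call the new op outside the test harness, a minimal usage sketch (sizes and variable names are illustrative, and the torchao C++ CPU extension must be built):

import torch

# Hypothetical standalone call of the new op; sizes are illustrative only.
num_rows, emb_dim = 1000, 128
weight_scale = torch.tensor([2.0])  # float32 scale, as the kernel expects
qweight = torch.randn(num_rows, emb_dim).to(torch.float8_e4m3fn)

indices = torch.randint(num_rows, (6,), dtype=torch.int64)
offsets = torch.tensor([0, 3, 6], dtype=torch.int64)  # include_last_offset=True

out = torch.ops.torchao.qembeddingbag(
    qweight, indices, offsets, weight_scale,
    1.0,   # o_scale
    0,     # mode: 0 == sum
    True,  # include_last_offset
)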
179 changes: 179 additions & 0 deletions torchao/csrc/cpu/qembeddingbag.cpp
@@ -0,0 +1,179 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec512/vec512_float8.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/EmbeddingBag.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Unroll.h>
#include <torch/all.h>

namespace torchao {

namespace {

#if defined(CPU_CAPABILITY_AVX512)
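// Load 16 fp8 (e4m3) values and widen them into a 512-bit vector of 16 fp32 lanes.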
static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
__m512 o;
__m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
return o;
}
#endif

template <typename index_t>
inline void qembeddingbag_kern(const int64_t bs_begin, const int64_t bs_end,
const int64_t num_emb, const int64_t emb_dim,
const index_t last_offset,
const index_t *indices, const index_t *offsets,
const at::Float8_e4m3fn *weight,
const double scale, float *result) {
#if defined(CPU_CAPABILITY_AVX512)
if (emb_dim % 128 == 0) {
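    // Fast path: walk each embedding row in blocks of 128 floats (eight
    // 16-lane AVX-512 fp32 registers), accumulate the bag sum in registers,
    // and do one scaled store per block.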
constexpr int64_t block_dim = 128;
const int64_t num_blocks = emb_dim / block_dim;
__m512 scale_v = _mm512_set1_ps(scale);
for (int64_t b = bs_begin; b < bs_end; ++b) {
__m512 x0, x1, x2, x3, x4, x5, x6, x7;
int64_t start_idx = offsets[b];
int64_t end_idx = ((b + 1) == bs_end && last_offset != -1)
? last_offset
: offsets[b + 1];
for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
// load first indices
int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
float *block_result = result + block_dim * block_id;
x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
for (int64_t j = start_idx + 1; j < end_idx; ++j) {
// add following idx
idx = indices[j] * emb_dim + block_dim * block_id;
x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
}
x0 = _mm512_mul_ps(x0, scale_v);
x1 = _mm512_mul_ps(x1, scale_v);
x2 = _mm512_mul_ps(x2, scale_v);
x3 = _mm512_mul_ps(x3, scale_v);
x4 = _mm512_mul_ps(x4, scale_v);
x5 = _mm512_mul_ps(x5, scale_v);
x6 = _mm512_mul_ps(x6, scale_v);
x7 = _mm512_mul_ps(x7, scale_v);
// store
_mm512_store_ps(block_result, x0);
_mm512_store_ps(block_result + 16, x1);
_mm512_store_ps(block_result + 32, x2);
_mm512_store_ps(block_result + 48, x3);
_mm512_store_ps(block_result + 64, x4);
_mm512_store_ps(block_result + 80, x5);
_mm512_store_ps(block_result + 96, x6);
_mm512_store_ps(block_result + 112, x7);
}
result += num_emb * emb_dim;
}
return;
}
#endif
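  // Scalar fallback (also used when emb_dim is not a multiple of 128):
  // dequantize each fp8 element and accumulate per output dimension.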
for (int64_t b = bs_begin; b < bs_end; ++b) {
int64_t start_idx = offsets[b];
int64_t end_idx =
((b + 1) == bs_end && last_offset != -1) ? last_offset : offsets[b + 1];
for (int64_t d = 0; d < emb_dim; d++) {
int64_t idx = indices[start_idx] * emb_dim;
float value = float(weight[idx + d]);
for (int64_t j = start_idx + 1; j < end_idx; ++j) {
idx = indices[j] * emb_dim;
value += float(weight[idx + d]);
}
value = value * scale;
result[d] = value;
}
result += num_emb * emb_dim;
}
}

template <typename index_t, typename data_t>
void qembeddingbagcat(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
index_t *offsets_ptr, int64_t num_batch, int64_t emb_dim,
index_t last_offset, double w_scale, double o_scale) {
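  // Split the batch into blocks of 512 bags and parallelize over
  // (batch block, table) pairs with OpenMP; the weight and output scales are
  // folded into a single multiplier before calling the kernel.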
constexpr int64_t b_block = 512;
const int64_t n_b_blocks = (num_batch - 1) / b_block + 1;
w_scale /= o_scale;
const int64_t num_emb = 1;
#pragma omp parallel for collapse(2)
for (int64_t b = 0; b < n_b_blocks; ++b) {
for (int64_t n = 0; n < num_emb; ++n) {
const int64_t bs_begin = b * b_block;
const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
      // pass last_offset so the final bag is bounded even if offsets lacks a trailing entry
qembeddingbag_kern(bs_begin, bs_end, num_emb, emb_dim, last_offset,
indices_ptr, offsets_ptr, w_ptr, w_scale, r);
}
}
}

at::Tensor qembeddingbag_impl(const at::Tensor &qweight,
const at::Tensor &indices,
const at::Tensor &offsets,
const at::Tensor &w_scales, double o_scale,
const int64_t mode, bool include_last_offset) {
  // Only support include_last_offset == true and mode ==
  // at::native::EmbeddingBagMode::SUM
  // TODO: Support more cases
  TORCH_CHECK(include_last_offset,
              "qembeddingbag: only support include_last_offset == true");
  TORCH_CHECK(mode == at::native::EmbeddingBagMode::SUM,
              "qembeddingbag: only support sum mode");
int64_t batch_size =
include_last_offset ? offsets.size(0) - 1 : offsets.size(0);
int64_t emb_dim = qweight.size(1);

auto index_type = indices.scalar_type();
auto qtype = qweight.scalar_type();
float w_scale = w_scales.data_ptr<float>()[0];

TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
"qembeddingbag: only accept contiguous input");
TORCH_CHECK(offsets.scalar_type() == index_type,
"qembeddingbag: index and offset must be of the same type");
TORCH_CHECK(qweight.is_contiguous(),
"qembeddingbag: only accept contiguous weight");
TORCH_CHECK(qweight.dim() == 2,
"qembeddingbag: only accept weight with dim == 2");
  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
              "qembeddingbag: only support e4m3fn weight");
  // With include_last_offset required, the last bag ends at indices.numel()
int64_t last_offset = indices.numel();

at::Tensor output =
at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
index_t *indices_ptr = indices.data_ptr<index_t>();
index_t *offsets_ptr = offsets.data_ptr<index_t>();
float *output_ptr = output.data_ptr<float>();
qembeddingbagcat<index_t, at::Float8_e4m3fn>(
output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
last_offset, w_scale, o_scale);
});
return output;
}

} // anonymous namespace

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
m.impl("torchao::qembeddingbag", &qembeddingbag_impl);
}

} // namespace torchao
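As a readability aid (not part of the PR), a rough pure-Python reference of the semantics the kernel implements, assuming include_last_offset=True and sum mode:

import torch

def qembeddingbag_ref(qweight, indices, offsets, w_scale, o_scale):
    # Per-bag sum of dequantized fp8 rows, scaled by w_scale / o_scale.
    # Assumes offsets has batch_size + 1 entries (include_last_offset=True).
    weight = qweight.to(torch.float32)
    batch_size = offsets.numel() - 1
    out = torch.empty(batch_size, weight.shape[1])
    for b in range(batch_size):
        start, end = int(offsets[b]), int(offsets[b + 1])
        rows = indices[start:end].long()
        out[b] = weight[rows].sum(dim=0) * (float(w_scale) / float(o_scale))
    return out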
19 changes: 19 additions & 0 deletions torchao/ops.py
@@ -70,6 +70,9 @@
lib.define(
"da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
)
lib.define(
"qembeddingbag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
)

Review thread on the qembeddingbag schema:

Collaborator: @jerryzh168 Thanks for reviewing. Yes, I think so, except that the implementation in this PR has limited functionality so far.

Contributor (author): This operator is used for inference, so I did not add any gradient-related parameters, including scale_grad_by_freq, sparse, per_sample_weights, and padding_idx.

Contributor (@jerryzh168, Aug 15, 2025): I think we should add this to PyTorch directly in that case. float8 is a native dtype in PyTorch, so it makes the most sense to add the functionality there; we can error out in the op if some argument combination is not supported or invalid for float8.

Contributor (author): Intel platforms have fp8 instructions, and when we are ready we hope to update this kernel to use them. As far as I know, the latest GCC is required. Would that be difficult to support in PyTorch?

Contributor: I'm not sure; can you open an issue for this in pytorch/pytorch?


def register_custom_op(name):
@@ -1106,3 +1109,19 @@ def _(
assert weight.dim() == 4
N = weight.size(0) * weight.size(3) * 2
return input.new_empty(*input.shape[:-1], N, dtype=out_dtype)


@register_custom_op("torchao::qembeddingbag")
def _(
qweight: Tensor,
indices: Tensor,
offsets: Tensor,
w_scales: Tensor,
o_scale: float,
mode: int,
include_last_offset: bool,
) -> Tensor:
    # Only include_last_offset == True is supported
    assert include_last_offset
    batch_size = offsets.shape[0] - 1
    # The CPU kernel dequantizes to float32, so the fake output is float32 as well
    return qweight.new_empty(batch_size, qweight.shape[1], dtype=torch.float32)
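For context, this fake (meta) registration lets FakeTensor and torch.compile trace the op from shapes alone, without running the C++ kernel; a hedged sketch of how it might be exercised (the wrapper name is illustrative):

import torch

@torch.compile
def lookup(qweight, indices, offsets, w_scales):
    # Tracing relies on the fake kernel registered above to infer the
    # output shape and dtype.
    return torch.ops.torchao.qembeddingbag(
        qweight, indices, offsets, w_scales, 1.0, 0, True
    )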