64 changes: 64 additions & 0 deletions test/test_ops.py
@@ -780,5 +780,69 @@ def test_swizzle_mm():
)


EMBEDDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
EMBEDDINGBAG_BAG_SIZES = [1, 2, 128]
EMBEDDINGBAG_VECTOR_SIZES = [1, 128, 512]
EMBEDDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]

EMBEDDINGBAG_TEST_PARAMS = list(
    itertools.product(
        EMBEDDINGBAG_MULTIHOT_SIZES,
        EMBEDDINGBAG_BAG_SIZES,
        EMBEDDINGBAG_VECTOR_SIZES,
        EMBEDDINGBAG_INDEX_DTYPES,
    )
)


@pytest.mark.skipif(
"CPU" not in torch._C._dispatch_dump("torchao::qembeddingbag"),
reason="cpp kernels not built",
)
@pytest.mark.parametrize(
"multi_hot, batch_size, vector_size, index_type",
    EMBEDDINGBAG_TEST_PARAMS,
ids=str,
)
def test_embeddingbag_cpu(multi_hot, batch_size, vector_size, index_type):
qtype = torch.float8_e4m3fn
dtype = torch.float32
weight_scale = torch.tensor([2.0])
include_last_offset = True
mode = "sum"

if mode == "sum":
mode_enum = 0
elif mode == "mean":
mode_enum = 1
elif mode == "max":
mode_enum = 2
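# Each bag contains exactly multi_hot indices; offsets has batch_size + 1
# entries to support include_last_offset=True.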
indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(index_type)

m = torch.nn.EmbeddingBag(
1000,
vector_size,
mode=mode,
dtype=dtype,
include_last_offset=include_last_offset,
)
fp8_weight = m.weight.data.to(qtype)
m.weight.data = fp8_weight.to(m.weight.dtype)
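# The reference module uses the fp8 weights cast back to fp32, so both paths
# accumulate identical values.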

with torch.no_grad():
refe_out = m.forward(indices, offsets) * weight_scale
test_out = torch.ops.torchao.qembeddingbag(
fp8_weight,
indices,
offsets,
weight_scale,
1.0,
mode_enum,
include_last_offset,
).to(dtype)
torch.testing.assert_close(refe_out, test_out, atol=0, rtol=0)
Contributor: is this too strict?

Contributor (Author): changed to 1e-5

if __name__ == "__main__":
pytest.main(sys.argv)
179 changes: 179 additions & 0 deletions torchao/csrc/cpu/qembeddingbag.cpp
@@ -0,0 +1,179 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/vec512/vec512_float8.h>
#include <ATen/native/CPUBlas.h>
#include <ATen/native/EmbeddingBag.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Unroll.h>
#include <torch/all.h>

namespace torchao {

namespace {

#if defined(CPU_CAPABILITY_AVX512)
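// Load 16 fp8 (e4m3fn) values and widen them to 16 fp32 lanes of a 512-bit register.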
static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
__m512 o;
__m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
return o;
}
#endif

template <typename index_t>
inline void qembeddingbag_kern(const int64_t bs_begin, const int64_t bs_end,
const int64_t num_emb, const int64_t emb_dim,
const index_t last_offset,
const index_t *indices, const index_t *offsets,
const at::Float8_e4m3fn *weight,
const double scale, float *result) {
#if defined(CPU_CAPABILITY_AVX512)
if (emb_dim % 128 == 0) {
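// Fast path: process the embedding dimension in blocks of 128 floats,
// i.e. 8 AVX-512 registers of 16 lanes each (x0..x7).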
constexpr int64_t block_dim = 128;
const int64_t num_blocks = emb_dim / block_dim;
__m512 scale_v = _mm512_set1_ps(scale);
for (int64_t b = bs_begin; b < bs_end; ++b) {
__m512 x0, x1, x2, x3, x4, x5, x6, x7;
int64_t start_idx = offsets[b];
int64_t end_idx = ((b + 1) == bs_end && last_offset != -1)
? last_offset
: offsets[b + 1];
for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
// load first indices
int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
float *block_result = result + block_dim * block_id;
x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
for (int64_t j = start_idx + 1; j < end_idx; ++j) {
// add following idx
idx = indices[j] * emb_dim + block_dim * block_id;
x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
}
x0 = _mm512_mul_ps(x0, scale_v);
x1 = _mm512_mul_ps(x1, scale_v);
x2 = _mm512_mul_ps(x2, scale_v);
x3 = _mm512_mul_ps(x3, scale_v);
x4 = _mm512_mul_ps(x4, scale_v);
x5 = _mm512_mul_ps(x5, scale_v);
x6 = _mm512_mul_ps(x6, scale_v);
x7 = _mm512_mul_ps(x7, scale_v);
// store
_mm512_store_ps(block_result, x0);
_mm512_store_ps(block_result + 16, x1);
_mm512_store_ps(block_result + 32, x2);
_mm512_store_ps(block_result + 48, x3);
_mm512_store_ps(block_result + 64, x4);
_mm512_store_ps(block_result + 80, x5);
_mm512_store_ps(block_result + 96, x6);
_mm512_store_ps(block_result + 112, x7);
}
result += num_emb * emb_dim;
}
return;
}
#endif
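// Scalar fallback, used when AVX-512 is unavailable or emb_dim is not a multiple of 128.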
for (int64_t b = bs_begin; b < bs_end; ++b) {
int64_t start_idx = offsets[b];
int64_t end_idx =
((b + 1) == bs_end && last_offset != -1) ? last_offset : offsets[b + 1];
for (int64_t d = 0; d < emb_dim; d++) {
int64_t idx = indices[start_idx] * emb_dim;
float value = float(weight[idx + d]);
for (int64_t j = start_idx + 1; j < end_idx; ++j) {
idx = indices[j] * emb_dim;
value += float(weight[idx + d]);
}
value = value * scale;
result[d] = value;
}
result += num_emb * emb_dim;
}
}

template <typename index_t, typename data_t>
void qembeddingbagcat(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
index_t *offsets_ptr, int64_t num_batch, int64_t emb_dim,
index_t last_offset, double w_scale, double o_scale) {
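// Tile the batch into chunks of 512 bags; each (batch tile, table) pair is one
// iteration of the collapsed OpenMP loop (num_emb is fixed to 1 here).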
constexpr int64_t b_block = 512;
const int64_t n_b_blocks = (num_batch - 1) / b_block + 1;
w_scale /= o_scale;
const int64_t num_emb = 1;
#pragma omp parallel for collapse(2)
for (int64_t b = 0; b < n_b_blocks; ++b) {
for (int64_t n = 0; n < num_emb; ++n) {
const int64_t bs_begin = b * b_block;
const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
// Only the tile containing the final bag needs last_offset (offsets may not
// include the end of the last bag); other tiles can always read offsets[b + 1].
const index_t tile_last_offset =
    (bs_end == num_batch) ? last_offset : static_cast<index_t>(-1);
qembeddingbag_kern(bs_begin, bs_end, num_emb, emb_dim, tile_last_offset,
                   indices_ptr, offsets_ptr, w_ptr, w_scale, r);
}
}
}

at::Tensor qembeddingbag_impl(const at::Tensor &qweight,
const at::Tensor &indices,
const at::Tensor &offsets,
const at::Tensor &w_scales, double o_scale,
const int64_t mode, bool include_last_offset) {
// Only support include_last_offset == true and mode ==
// at::native::EmbeddingBagMode::SUM
// TODO: Support more cases
TORCH_CHECK(include_last_offset,
            "qembeddingbag: only support include_last_offset == true");
TORCH_CHECK(mode == at::native::EmbeddingBagMode::SUM,
            "qembeddingbag: only support sum mode");
int64_t batch_size =
include_last_offset ? offsets.size(0) - 1 : offsets.size(0);
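// With include_last_offset == true, offsets holds batch_size + 1 entries; the
// trailing entry marks the end of the last bag.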
int64_t emb_dim = qweight.size(1);

auto index_type = indices.scalar_type();
auto qtype = qweight.scalar_type();
float w_scale = w_scales.data_ptr<float>()[0];

TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
"qembeddingbag: only accept contiguous input");
TORCH_CHECK(offsets.scalar_type() == index_type,
"qembeddingbag: index and offset must be of the same type");
TORCH_CHECK(qweight.is_contiguous(),
"qembeddingbag: only accept contiguous weight");
TORCH_CHECK(qweight.dim() == 2,
"qembeddingbag: only accept weight with dim == 2");
TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
            "qembeddingbag: only support e4m3fn weight");
// End index of the last bag (include_last_offset == true, so indices.numel())
int64_t last_offset = indices.numel();

at::Tensor output =
at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
index_t *indices_ptr = indices.data_ptr<index_t>();
index_t *offsets_ptr = offsets.data_ptr<index_t>();
float *output_ptr = output.data_ptr<float>();
qembeddingbagcat<index_t, at::Float8_e4m3fn>(
output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
last_offset, w_scale, o_scale);
});
return output;
}

} // anonymous namespace

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
m.impl("torchao::qembeddingbag", &qembeddingbag_impl);
}

} // namespace torchao
19 changes: 19 additions & 0 deletions torchao/ops.py
@@ -70,6 +70,9 @@
lib.define(
"da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
)
lib.define(
"qembeddingbag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
Collaborator: @jerryzh168 Thanks for reviewing. Yes, I think so, except that the implementation in this PR has limited functionality so far.

Contributor (Author): This operator is used for inference, so I did not add any gradient-related parameters (scale_grad_by_freq, sparse, per_sample_weights, padding_idx).

Contributor (@jerryzh168, Aug 15, 2025): I think we should add this to PyTorch directly if that's the case. float8 is a native dtype in PyTorch, so it makes the most sense to add the functionality there; we can error out in the op if some argument combination is not supported or is invalid for float8.

Contributor (Author): Intel platforms have fp8 instructions, and when we are ready we hope to update this kernel to use them. As far as I know, that requires the latest GCC. Would that be difficult to support in PyTorch?

Collaborator: So, since this PR adds a quantized version of this op, do you think it is better to add it in torchao rather than in torch core? Thanks.

Contributor: Yeah, my question is whether this can be implemented by extending the embedding_bag op in PyTorch and doing the scaling in torchao, or whether performance would be a concern there.

Contributor (Author): This is a memory-bound operator, so repeated reading and writing of the output causes significant performance degradation. For example, where we currently read and write the output once (and DLRM hits this op many times), doing the scaling in a separate pass would require reading and writing it twice, roughly halving performance.
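For illustration, a minimal sketch of the two approaches under discussion, reusing the tensors from the test above; the unfused path is an assumed equivalent built on torch.nn.functional.embedding_bag, not code from this PR:

# Unfused: embedding_bag writes the fp32 output, then the scaling pass reads
# and rewrites it, so the output tensor is traversed twice.
ref = torch.nn.functional.embedding_bag(
    indices, fp8_weight.to(torch.float32), offsets,
    mode="sum", include_last_offset=True,
) * weight_scale

# Fused op from this PR: the scale is applied while the partial sums are still
# in registers, so the output is written exactly once.
out = torch.ops.torchao.qembeddingbag(
    fp8_weight, indices, offsets, weight_scale, 1.0, 0, True
)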

Contributor (@jerryzh168, Aug 22, 2025): OK, sounds good. Maybe rename this to _scaled_embedding_bag to follow these ops: https://github.com/pytorch/pytorch/blob/31a41daff49f2cde941d8b9e35cb2eaeeb606c0d/aten/src/ATen/native/native_functions.yaml#L7135

The leading underscore indicates it is a prototype op, since you may want to update the argument list, expand hardware coverage, etc. later.

Contributor (Author): Done

)


def register_custom_op(name):
@@ -1106,3 +1109,19 @@ def _(
assert weight.dim() == 4
N = weight.size(0) * weight.size(3) * 2
return input.new_empty(*input.shape[:-1], N, dtype=out_dtype)


@register_custom_op("torchao::qembeddingbag")
def _(
qweight: Tensor,
indices: Tensor,
offsets: Tensor,
w_scales: Tensor,
o_scale: float,
mode: int,
include_last_offset: bool,
) -> Tensor:
# Only include_last_offset == True is supported
assert include_last_offset
batch_size = offsets.shape[0] - 1
# The CPU kernel always produces a float32 output, so the fake tensor must too.
return qweight.new_empty(batch_size, qweight.shape[1], dtype=torch.float32)
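A minimal sketch of what this fake registration enables, assuming register_custom_op wires the function up as a torch.library fake implementation like the other ops in this file: output shapes propagate under FakeTensorMode without running the CPU kernel. Tensor sizes below are illustrative.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    qweight = torch.empty(1000, 128, dtype=torch.float8_e4m3fn)
    indices = torch.zeros(64, dtype=torch.int64)
    offsets = torch.arange(0, 65, 2, dtype=torch.int64)  # 32 bags plus trailing offset
    scale = torch.ones(1)
    out = torch.ops.torchao.qembeddingbag(
        qweight, indices, offsets, scale, 1.0, 0, True
    )
    # out.shape == (32, 128); no real computation was performed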