diff --git a/test/test_ops.py b/test/test_ops.py
index 89512b673d..a46f5e4ff8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -764,5 +764,69 @@ def test_swizzle_mm():
     )
 
 
+EMBEDINGBAG_MULTIHOT_SIZES = [1, 2, 3, 10]
+EMBEDINGBAG_BAG_SIZES = [1, 2, 128, 1024]
+EMBEDINGBAG_VECTOR_SIZES = [1, 128, 512]
+EMBEDINGBAG_INDEX_DTYPES = [torch.int64, torch.int32]
+
+EMBEDINGBAG_TEST_PARAMS = list(
+    itertools.product(
+        EMBEDINGBAG_MULTIHOT_SIZES,
+        EMBEDINGBAG_BAG_SIZES,
+        EMBEDINGBAG_VECTOR_SIZES,
+        EMBEDINGBAG_INDEX_DTYPES,
+    )
+)
+
+
+@pytest.mark.skipif(
+    "CPU" not in torch._C._dispatch_dump("torchao::_scaled_embedding_bag"),
+    reason="cpp kernels not built",
+)
+@pytest.mark.parametrize(
+    "multi_hot, batch_size, vector_size, index_type",
+    EMBEDINGBAG_TEST_PARAMS,
+    ids=str,
+)
+def test_scaled_embedding_bag_cpu(multi_hot, batch_size, vector_size, index_type):
+    qtype = torch.float8_e4m3fn
+    dtype = torch.float32
+    weight_scale = torch.tensor([2.0])
+    include_last_offset = True
+    mode = "sum"
+
+    if mode == "sum":
+        mode_enum = 0
+    elif mode == "mean":
+        mode_enum = 1
+    elif mode == "max":
+        mode_enum = 2
+    indices = torch.randint(1000, (batch_size * multi_hot,)).to(index_type)
+    offsets = torch.arange(0, (batch_size + 1) * multi_hot, multi_hot).to(index_type)
+
+    m = torch.nn.EmbeddingBag(
+        1000,
+        vector_size,
+        mode=mode,
+        dtype=dtype,
+        include_last_offset=include_last_offset,
+    )
+    fp8_weight = m.weight.data.to(qtype)
+    m.weight.data = fp8_weight.to(m.weight.dtype)
+
+    with torch.no_grad():
+        refe_out = m.forward(indices, offsets) * weight_scale
+        test_out = torch.ops.torchao._scaled_embedding_bag(
+            fp8_weight,
+            indices,
+            offsets,
+            weight_scale,
+            1.0,
+            mode_enum,
+            include_last_offset,
+        ).to(dtype)
+        torch.testing.assert_close(refe_out, test_out, atol=1e-5, rtol=1e-5)
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
diff --git a/torchao/csrc/cpu/scaled_embedding_bag.cpp b/torchao/csrc/cpu/scaled_embedding_bag.cpp
new file mode 100644
index 0000000000..1063f353c4
--- /dev/null
+++ b/torchao/csrc/cpu/scaled_embedding_bag.cpp
@@ -0,0 +1,183 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/cpu/vec/vec.h>
+#include <ATen/native/EmbeddingBag.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <torch/library.h>
+#include <algorithm>
+
+namespace torchao {
+
+namespace {
+
+#if defined(CPU_CAPABILITY_AVX512)
+static inline __m512 _mm512_load_e4m3_cvt_ps(const at::Float8_e4m3fn *x) {
+  __m512 o;
+  __m128i v = _mm_loadu_si128(reinterpret_cast<const __m128i *>(x));
+  at::vec::CPU_CAPABILITY::cvtfp8e4m3_fp32(v, o);
+  return o;
+}
+#endif
+
+template <typename index_t>
+inline void _scaled_embedding_bag_krnl(
+    const int64_t bs_begin, const int64_t bs_end, const int64_t num_emb,
+    const int64_t emb_dim, const index_t last_offset, const index_t *indices,
+    const index_t *offsets, const at::Float8_e4m3fn *weight, const double scale,
+    float *result, const int64_t num_batch) {
+#if defined(CPU_CAPABILITY_AVX512)
+  if (emb_dim % 128 == 0) {
+    constexpr int64_t block_dim = 128;
+    const int64_t num_blocks = emb_dim / block_dim;
+    __m512 scale_v = _mm512_set1_ps(scale);
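+    // Fast path: each 128-float block of an output row is accumulated in
+    // eight 512-bit registers (8 x 16 floats); the fp8 rows selected by the
+    // bag's indices are summed and the combined scale is applied once before
+    // storing the block.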
+    for (int64_t b = bs_begin; b < bs_end; ++b) {
+      __m512 x0, x1, x2, x3, x4, x5, x6, x7;
+      int64_t start_idx = offsets[b];
+      int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
+                            ? last_offset
+                            : offsets[b + 1];
+      for (int64_t block_id = 0; block_id < num_blocks; block_id++) {
+        // load first indices
+        int64_t idx = indices[start_idx] * emb_dim + block_dim * block_id;
+        float *block_result = result + block_dim * block_id;
+        x0 = _mm512_load_e4m3_cvt_ps(&weight[idx]);
+        x1 = _mm512_load_e4m3_cvt_ps(&weight[idx + 16]);
+        x2 = _mm512_load_e4m3_cvt_ps(&weight[idx + 32]);
+        x3 = _mm512_load_e4m3_cvt_ps(&weight[idx + 48]);
+        x4 = _mm512_load_e4m3_cvt_ps(&weight[idx + 64]);
+        x5 = _mm512_load_e4m3_cvt_ps(&weight[idx + 80]);
+        x6 = _mm512_load_e4m3_cvt_ps(&weight[idx + 96]);
+        x7 = _mm512_load_e4m3_cvt_ps(&weight[idx + 112]);
+        for (int64_t j = start_idx + 1; j < end_idx; ++j) {
+          // add following idx
+          idx = indices[j] * emb_dim + block_dim * block_id;
+          x0 = _mm512_add_ps(x0, _mm512_load_e4m3_cvt_ps(&weight[idx]));
+          x1 = _mm512_add_ps(x1, _mm512_load_e4m3_cvt_ps(&weight[idx + 16]));
+          x2 = _mm512_add_ps(x2, _mm512_load_e4m3_cvt_ps(&weight[idx + 32]));
+          x3 = _mm512_add_ps(x3, _mm512_load_e4m3_cvt_ps(&weight[idx + 48]));
+          x4 = _mm512_add_ps(x4, _mm512_load_e4m3_cvt_ps(&weight[idx + 64]));
+          x5 = _mm512_add_ps(x5, _mm512_load_e4m3_cvt_ps(&weight[idx + 80]));
+          x6 = _mm512_add_ps(x6, _mm512_load_e4m3_cvt_ps(&weight[idx + 96]));
+          x7 = _mm512_add_ps(x7, _mm512_load_e4m3_cvt_ps(&weight[idx + 112]));
+        }
+        x0 = _mm512_mul_ps(x0, scale_v);
+        x1 = _mm512_mul_ps(x1, scale_v);
+        x2 = _mm512_mul_ps(x2, scale_v);
+        x3 = _mm512_mul_ps(x3, scale_v);
+        x4 = _mm512_mul_ps(x4, scale_v);
+        x5 = _mm512_mul_ps(x5, scale_v);
+        x6 = _mm512_mul_ps(x6, scale_v);
+        x7 = _mm512_mul_ps(x7, scale_v);
+        // store
+        _mm512_store_ps(block_result, x0);
+        _mm512_store_ps(block_result + 16, x1);
+        _mm512_store_ps(block_result + 32, x2);
+        _mm512_store_ps(block_result + 48, x3);
+        _mm512_store_ps(block_result + 64, x4);
+        _mm512_store_ps(block_result + 80, x5);
+        _mm512_store_ps(block_result + 96, x6);
+        _mm512_store_ps(block_result + 112, x7);
+      }
+      result += num_emb * emb_dim;
+    }
+    return;
+  }
+#endif
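+  // Scalar fallback: used when AVX512 is not available or emb_dim is not a
+  // multiple of 128; accumulates each output element in fp32, one at a time.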
+  for (int64_t b = bs_begin; b < bs_end; ++b) {
+    int64_t start_idx = offsets[b];
+    int64_t end_idx = ((b + 1) == num_batch && last_offset != -1)
+                          ? last_offset
+                          : offsets[b + 1];
+    for (int64_t d = 0; d < emb_dim; d++) {
+      int64_t idx = indices[start_idx] * emb_dim;
+      float value = float(weight[idx + d]);
+      for (int64_t j = start_idx + 1; j < end_idx; ++j) {
+        idx = indices[j] * emb_dim;
+        value += float(weight[idx + d]);
+      }
+      value = value * scale;
+      result[d] = value;
+    }
+    result += num_emb * emb_dim;
+  }
+}
+
+template <typename index_t, typename data_t>
+void _scaled_embedding_bag(float *o_ptr, data_t *w_ptr, index_t *indices_ptr,
+                           index_t *offsets_ptr, int64_t num_batch,
+                           int64_t emb_dim, index_t last_offset, double w_scale,
+                           double o_scale) {
+  constexpr int64_t b_block = 512;
+  const int64_t n_b_blocks = (num_batch - 1) / b_block + 1;
+  w_scale /= o_scale;
+  const int64_t num_emb = 1;
+#pragma omp parallel for collapse(2)
+  for (int64_t b = 0; b < n_b_blocks; ++b) {
+    for (int64_t n = 0; n < num_emb; ++n) {
+      const int64_t bs_begin = b * b_block;
+      const int64_t bs_end = std::min(num_batch, (b + 1) * b_block);
+      float *r = &o_ptr[b * b_block * num_emb * emb_dim + n * emb_dim];
+      // handle offsets that do not include the last batch
+      _scaled_embedding_bag_krnl(bs_begin, bs_end, num_emb, emb_dim,
+                                 last_offset, indices_ptr, offsets_ptr, w_ptr,
+                                 w_scale, r, num_batch);
+    }
+  }
+}
+
+at::Tensor _scaled_embedding_bag_impl(const at::Tensor &qweight,
+                                      const at::Tensor &indices,
+                                      const at::Tensor &offsets,
+                                      const at::Tensor &w_scales,
+                                      double o_scale, const int64_t mode,
+                                      bool include_last_offset) {
+  // Only support include_last_offset == True and mode ==
+  // at::native::EmbeddingBagMode::SUM
+  // TODO: Support more cases
+  TORCH_CHECK(include_last_offset,
+              "_scaled_embedding_bag: only support include_last_offset");
+  TORCH_CHECK(mode == at::native::EmbeddingBagMode::SUM,
+              "_scaled_embedding_bag: only support sum mode");
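+  // With include_last_offset == true, offsets holds batch_size + 1 entries
+  // and its final entry marks the end of the last bag, so the number of bags
+  // is offsets.size(0) - 1.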
+  int64_t batch_size =
+      include_last_offset ? offsets.size(0) - 1 : offsets.size(0);
+  int64_t emb_dim = qweight.size(1);
+
+  auto index_type = indices.scalar_type();
+  auto qtype = qweight.scalar_type();
+  float w_scale = w_scales.data_ptr<float>()[0];
+
+  TORCH_CHECK(indices.is_contiguous() && offsets.is_contiguous(),
+              "_scaled_embedding_bag: only accept contiguous input");
+  TORCH_CHECK(
+      offsets.scalar_type() == index_type,
+      "_scaled_embedding_bag: index and offset must be of the same type");
+  TORCH_CHECK(qweight.is_contiguous(),
+              "_scaled_embedding_bag: only accept contiguous weight");
+  TORCH_CHECK(qweight.dim() == 2,
+              "_scaled_embedding_bag: only accept weight with dim == 2");
+  TORCH_CHECK(qweight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
+              "_scaled_embedding_bag: only support e4m3fn weight");
+  // handle last offsets
+  int64_t last_offset = indices.numel();
+
+  at::Tensor output =
+      at::empty({batch_size, emb_dim}, qweight.options().dtype(at::kFloat));
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embeddingbag_cat", [&] {
+    at::Float8_e4m3fn *qweight_ptr = qweight.data_ptr<at::Float8_e4m3fn>();
+    index_t *indices_ptr = indices.data_ptr<index_t>();
+    index_t *offsets_ptr = offsets.data_ptr<index_t>();
+    float *output_ptr = output.data_ptr<float>();
+    _scaled_embedding_bag<index_t, at::Float8_e4m3fn>(
+        output_ptr, qweight_ptr, indices_ptr, offsets_ptr, batch_size, emb_dim,
+        last_offset, w_scale, o_scale);
+  });
+  return output;
+}
+
+} // anonymous namespace
+
+TORCH_LIBRARY_IMPL(torchao, CPU, m) {
+  m.impl("torchao::_scaled_embedding_bag", &_scaled_embedding_bag_impl);
+}
+
+} // namespace torchao
\ No newline at end of file
diff --git a/torchao/ops.py b/torchao/ops.py
index 4b643cae98..b6348f90a5 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -68,6 +68,9 @@
 lib.define(
     "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor"
 )
+lib.define(
+    "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor"
+)
 
 
 def register_custom_op(name):
@@ -1098,3 +1101,19 @@ def _(
     assert weight.dim() == 4
     N = weight.size(0) * weight.size(3) * 2
     return input.new_empty(*input.shape[:-1], N, dtype=out_dtype)
+
+
+@register_custom_op("torchao::_scaled_embedding_bag")
+def _(
+    qweight: Tensor,
+    indices: Tensor,
+    offsets: Tensor,
+    w_scales: Tensor,
+    o_scale: float,
+    mode: int,
+    include_last_offset: bool,
+) -> Tensor:
+    # Only include_last_offset == True is supported
+    assert include_last_offset
+    batch_size = offsets.shape[0] - 1
+    # the CPU kernel always produces a float32 output
+    return qweight.new_empty(batch_size, qweight.shape[1], dtype=torch.float32)