Skip to content

Commit 7d82ab6

Browse files
garroudfacebook-github-bot
authored andcommitted
add meta impl for int4 preshuffle kernels (#4384)
Summary: X-link: facebookresearch/FBGEMM#1458 Pull Request resolved: #4384 att. add fake impl to integrate with AOTI Reviewed By: jwfromm Differential Revision: D76834825 fbshipit-source-id: a20673eb3d48d83ae8a9e93456030eca9f2550b5
1 parent bf455b2 commit 7d82ab6

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,28 @@ at::Tensor f8i4bf16_rowwise_meta(
588588
return Y;
589589
}
590590

591+
std::tuple<at::Tensor, at::Tensor> preshuffle_i4_meta(
592+
at::Tensor WQ,
593+
at::Tensor w_scale) {
594+
auto WS = at::empty_like(w_scale);
595+
if (w_scale.dtype() != at::kBFloat16) {
596+
WS = at::empty({w_scale.size(0), 8, w_scale.size(1)}, w_scale.options());
597+
}
598+
return {at::empty_like(WQ), WS};
599+
}
600+
601+
at::Tensor f8i4bf16_shuffled_meta(
602+
at::Tensor XQ, // FP8
603+
at::Tensor WQ, // INT4
604+
at::Tensor /* x_scale */,
605+
at::Tensor /* w_scale */,
606+
at::Tensor /* w_scale_group */) {
607+
const at::SymInt M = XQ.sym_size(0);
608+
const at::SymInt N = WQ.sym_size(0);
609+
auto Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
610+
return Y;
611+
}
612+
591613
at::Tensor bf16i4bf16_rowwise_meta(
592614
at::Tensor X, // BF16
593615
at::Tensor W, // INT4
@@ -723,6 +745,8 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
723745
m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
724746
m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
725747
m.impl("scaled_fp4_quant", scaled_fp4_quant_meta);
748+
m.impl("preshuffle_i4", preshuffle_i4_meta);
749+
m.impl("f8i4bf16_shuffled", f8i4bf16_shuffled_meta);
726750
#endif
727751
#ifdef USE_ROCM
728752
m.impl("f8f8f16_rowwise", f8f8f16_rowwise_meta);

0 commit comments

Comments
 (0)