@@ -588,6 +588,28 @@ at::Tensor f8i4bf16_rowwise_meta(
   return Y;
 }
 
+std::tuple<at::Tensor, at::Tensor> preshuffle_i4_meta(
+    at::Tensor WQ,
+    at::Tensor w_scale) {
+  auto WS = at::empty_like(w_scale);
+  if (w_scale.dtype() != at::kBFloat16) {
+    WS = at::empty({w_scale.size(0), 8, w_scale.size(1)}, w_scale.options());
+  }
+  return {at::empty_like(WQ), WS};
+}
+
+at::Tensor f8i4bf16_shuffled_meta(
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // INT4
+    at::Tensor /* x_scale */,
+    at::Tensor /* w_scale */,
+    at::Tensor /* w_scale_group */) {
+  const at::SymInt M = XQ.sym_size(0);
+  const at::SymInt N = WQ.sym_size(0);
+  auto Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
+  return Y;
+}
+
 at::Tensor bf16i4bf16_rowwise_meta(
     at::Tensor X, // BF16
     at::Tensor W, // INT4
@@ -723,6 +745,8 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
   m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
   m.impl("scaled_fp4_quant", scaled_fp4_quant_meta);
+  m.impl("preshuffle_i4", preshuffle_i4_meta);
+  m.impl("f8i4bf16_shuffled", f8i4bf16_shuffled_meta);
 #endif
 #ifdef USE_ROCM
   m.impl("f8f8f16_rowwise", f8f8f16_rowwise_meta);