Skip to content

[mlir][ArithToAMDGPU] Use native packing support #150342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 24, 2025

Conversation

krzysz00
Copy link
Contributor

The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf don't take full advantage of the native packing ability of the intrinsics being targetted. Scaling extension takes the location of the two elements to be extended as a constant argument (byte for fp4, half for fp8), and scaling truncation takes a 32-bit input register and a byte or half to write the truncated values to.

Not using these features would cause excess unneeded register pressure. This PR resolves the inefficiency.

It also adds a test for the expected usecase of extending or truncateting a block of 32 values to/from fp4 with a uniform scale to ensure that this usage has a minimal amount of vector shuffling.

The current arith-to-amdgpu patterns for scaling_extf and
scaling_truncf don't take full advantage of the native packing ability
of the intrinsics being targetted. Scaling extension takes the
location of the two elements to be extended as a constant
argument (byte for fp4, half for fp8), and scaling truncation takes a
32-bit input register and a byte or half to write the truncated values
to.

Not using these features would cause excess unneeded register
pressure. This PR resolves the inefficiency.

It also adds a test for the expected usecase of extending or
truncateting a block of 32 values to/from fp4 with a uniform scale to
ensure that this usage has a minimal amount of vector shuffling.
@llvmbot
Copy link
Member

llvmbot commented Jul 24, 2025

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir-gpu

Author: Krzysztof Drewniak (krzysz00)

Changes

The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf don't take full advantage of the native packing ability of the intrinsics being targetted. Scaling extension takes the location of the two elements to be extended as a constant argument (byte for fp4, half for fp8), and scaling truncation takes a 32-bit input register and a byte or half to write the truncated values to.

Not using these features would cause excess unneeded register pressure. This PR resolves the inefficiency.

It also adds a test for the expected usecase of extending or truncateting a block of 32 values to/from fp4 with a uniform scale to ensure that this usage has a minimal amount of vector shuffling.


Patch is 28.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150342.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+62-38)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir (+24-22)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir (+27-33)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8c68b57877c35..4cf80167b20c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -449,7 +449,8 @@ LogicalResult
 ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opOutWidth = 2;
+  constexpr int64_t opInWidth = 8;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -473,7 +474,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   else if (scaleType.getIntOrFloatBitWidth() > 32)
     scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
 
-  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +488,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
   SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+  if (origScaleVecType)
     llvm::append_range(originalScaleShape, origScaleVecType.getShape());
 
   originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +526,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
          i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
-      Value scaleExt = amdgpu::ScaledExtPackedOp::create(
-          rewriter, loc, extScaleResultType, slice, uniformScale, 0);
-      if (sliceWidth != opWidth)
-        scaleExt = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleExt, 0, sliceWidth, 1);
-      blockResult = vector::InsertStridedSliceOp::create(
-          rewriter, loc, scaleExt, blockResult, i, 1);
+         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+      Value inSlice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, block1D, i, inSliceWidth, 1);
+      for (int64_t j = 0,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+           j < inSliceWidth; j += outSliceWidth,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+            rewriter, loc, extScaleResultType, inSlice, uniformScale,
+            j / opOutWidth);
+        if (outSliceWidth < opOutWidth) {
+          scaleExt = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+        }
+        blockResult = vector::InsertStridedSliceOp::create(
+            rewriter, loc, scaleExt, blockResult, i + j, 1);
+      }
     }
 
     VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +564,7 @@ LogicalResult
 ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                              PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opInWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -568,7 +577,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
   if (outVecType && outVecType.isScalable())
     return failure();
 
@@ -581,8 +589,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                          rewriter.getFloatAttr(outType, 0.0));
-  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
-  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
@@ -598,16 +606,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
-  SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
-    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+  SmallVector<int64_t> scaleShape;
+  if (origScaleVecType)
+    llvm::append_range(scaleShape, origScaleVecType.getShape());
 
-  originalScaleShape.insert(originalScaleShape.end(),
-                            inShape.size() - originalScaleShape.size(), 1);
+  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
 
-  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
   assert(maybeRatio &&
          "failed to derive block size from broadcast or splat operation");
 
@@ -633,20 +641,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
-         i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
-      Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
-          rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
-          /*existing=*/nullptr);
-      int64_t packedWidth =
-          cast<VectorType>(scaleTrunc.getType()).getNumElements();
-      if (packedWidth != opWidth)
+    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+         i < blockSize; i += outSliceWidth,
+                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+      Value scaleTrunc;
+      // Case where <= 2 elements are being truncated.
+      if (outSliceWidth <= opInWidth) {
+        Value slice = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, block1D, i, outSliceWidth, 1);
+        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+            /*existing=*/nullptr);
+      } else {
+        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+                                                 truncScaleResultType, zero);
+        for (int64_t j = 0,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+             j < outSliceWidth; j += opInWidth,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+          Value slice = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, block1D, i + j, inSliceWidth, 1);
+          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+              rewriter, loc, truncScaleResultType, slice, uniformScale,
+              j / opInWidth, scaleTrunc);
+        }
+      }
+      if (outSliceWidth != opOutWidth) {
         scaleTrunc = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+      }
       blockResult = vector::InsertStridedSliceOp::create(
           rewriter, loc, scaleTrunc, blockResult, i, 1);
     }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index b98045195f8cf..a837bdb8be4fa 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<
 // CHECK-DAG:     %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
 // CHECK-DAG:     %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
 // CHECK-DAG:     vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
-// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    %[[IN_SLICE_CAST:.+]] = vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.scaled_ext_packed
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.scaled_ext_packed
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT:    %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0]
+// CHECK-NEXT:    vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1]
+// CHECK-NEXT:    vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.scaled_ext_packed
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.scaled_ext_packed
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} 
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
 func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
     %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
     %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
@@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8
 // CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
 // CHECK-NEXT:    %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT:    %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT:    %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
 // CHECK-NEXT:    %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT:    %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT:    %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
@@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
 // CHECK-DAG:     %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
 // CHECK-DAG:     %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
 // CHECK-DAG:     %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
-// CHECK:         %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK:         %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 // CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<4xf32>
 func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
@@ -261,3 +251,15 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
     %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
     return %ext : f32
 }
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> {
+    %splat = vector.broadcast %scale : f32 to vector<32xf32>
+    %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
+    return %ext : vector<32xf32>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 488e75cbb1843..6d6e1e28d2c2c 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -88,28 +88,20 @@ func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    %[[P1:.+]] = amdgpu.packed_scaled_trunc
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
-// CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT:    %[[P2:.+]] = amdgpu.packed_scaled_trunc {{.*}} into %[[P1]][1]
+// CHECK-NEXT:    %[[P2_CAST:.+]] = vector.shape_cast %[[P2]] : vector<4xf8E5M2> to vector<1x1x4xf8E5M2>
+// CHECK-NEXT:    vector.insert_strided_slice %[[P2_CAST]], %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-N...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Jul 24, 2025

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf don't take full advantage of the native packing ability of the intrinsics being targetted. Scaling extension takes the location of the two elements to be extended as a constant argument (byte for fp4, half for fp8), and scaling truncation takes a 32-bit input register and a byte or half to write the truncated values to.

Not using these features would cause excess unneeded register pressure. This PR resolves the inefficiency.

It also adds a test for the expected usecase of extending or truncateting a block of 32 values to/from fp4 with a uniform scale to ensure that this usage has a minimal amount of vector shuffling.


Patch is 28.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150342.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+62-38)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir (+24-22)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir (+27-33)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8c68b57877c35..4cf80167b20c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -449,7 +449,8 @@ LogicalResult
 ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opOutWidth = 2;
+  constexpr int64_t opInWidth = 8;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -473,7 +474,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
   else if (scaleType.getIntOrFloatBitWidth() > 32)
     scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
 
-  VectorType extScaleResultType = VectorType::get(opWidth, outType);
+  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +488,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
   SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+  if (origScaleVecType)
     llvm::append_range(originalScaleShape, origScaleVecType.getShape());
 
   originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +526,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
          i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
-      Value scaleExt = amdgpu::ScaledExtPackedOp::create(
-          rewriter, loc, extScaleResultType, slice, uniformScale, 0);
-      if (sliceWidth != opWidth)
-        scaleExt = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleExt, 0, sliceWidth, 1);
-      blockResult = vector::InsertStridedSliceOp::create(
-          rewriter, loc, scaleExt, blockResult, i, 1);
+         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+      Value inSlice = vector::ExtractStridedSliceOp::create(
+          rewriter, loc, block1D, i, inSliceWidth, 1);
+      for (int64_t j = 0,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+           j < inSliceWidth; j += outSliceWidth,
+                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+            rewriter, loc, extScaleResultType, inSlice, uniformScale,
+            j / opOutWidth);
+        if (outSliceWidth < opOutWidth) {
+          scaleExt = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+        }
+        blockResult = vector::InsertStridedSliceOp::create(
+            rewriter, loc, scaleExt, blockResult, i + j, 1);
+      }
     }
 
     VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +564,7 @@ LogicalResult
 ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                              PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
-  constexpr int64_t opWidth = 2;
+  constexpr int64_t opInWidth = 2;
 
   Value in = op.getIn();
   Value scale = op.getScale();
@@ -568,7 +577,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType outVecType = dyn_cast<VectorType>(out.getType());
   VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
   if (outVecType && outVecType.isScalable())
     return failure();
 
@@ -581,8 +589,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                          rewriter.getFloatAttr(outType, 0.0));
-  unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
-  VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
 
   if (!outVecType) {
     Type inVecType = VectorType::get(1, inType);
@@ -598,16 +606,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
   VectorType inVecType = cast<VectorType>(in.getType());
   Value origScale = getOriginalVectorValue(op.getScale());
+  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
 
   ArrayRef<int64_t> inShape = inVecType.getShape();
-  SmallVector<int64_t> originalScaleShape;
-  if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
-    llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+  SmallVector<int64_t> scaleShape;
+  if (origScaleVecType)
+    llvm::append_range(scaleShape, origScaleVecType.getShape());
 
-  originalScaleShape.insert(originalScaleShape.end(),
-                            inShape.size() - originalScaleShape.size(), 1);
+  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
 
-  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
   assert(maybeRatio &&
          "failed to derive block size from broadcast or splat operation");
 
@@ -633,20 +641,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
     Value blockResult =
         rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
 
-    for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
-         i < blockSize;
-         i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
-      Value slice = vector::ExtractStridedSliceOp::create(
-          rewriter, loc, block1D, i, sliceWidth, 1);
-      // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
-      Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
-          rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
-          /*existing=*/nullptr);
-      int64_t packedWidth =
-          cast<VectorType>(scaleTrunc.getType()).getNumElements();
-      if (packedWidth != opWidth)
+    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+         i < blockSize; i += outSliceWidth,
+                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+      Value scaleTrunc;
+      // Case where <= 2 elements are being truncated.
+      if (outSliceWidth <= opInWidth) {
+        Value slice = vector::ExtractStridedSliceOp::create(
+            rewriter, loc, block1D, i, outSliceWidth, 1);
+        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+            /*existing=*/nullptr);
+      } else {
+        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+                                                 truncScaleResultType, zero);
+        for (int64_t j = 0,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+             j < outSliceWidth; j += opInWidth,
+                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+          Value slice = vector::ExtractStridedSliceOp::create(
+              rewriter, loc, block1D, i + j, inSliceWidth, 1);
+          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+              rewriter, loc, truncScaleResultType, slice, uniformScale,
+              j / opInWidth, scaleTrunc);
+        }
+      }
+      if (outSliceWidth != opOutWidth) {
         scaleTrunc = vector::ExtractStridedSliceOp::create(
-            rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+      }
       blockResult = vector::InsertStridedSliceOp::create(
           rewriter, loc, scaleTrunc, blockResult, i, 1);
     }
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index b98045195f8cf..a837bdb8be4fa 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<
 // CHECK-DAG:     %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
 // CHECK-DAG:     %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
 // CHECK-DAG:     vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
-// CHECK-NEXT:    vector.shape_cast
+// CHECK-NEXT:    %[[IN_SLICE_CAST:.+]] = vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.scaled_ext_packed
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.scaled_ext_packed
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT:    %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0]
+// CHECK-NEXT:    vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1]
+// CHECK-NEXT:    vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.scaled_ext_packed
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.scaled_ext_packed
 // CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
 // CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]} 
+// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
 func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
     %bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
     %cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
@@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8
 // CHECK-NEXT:    %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
 // CHECK-NEXT:    %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT:    %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT:    %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
 // CHECK-NEXT:    %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
 // CHECK-NEXT:    %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT:    %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
 // CHECK-NEXT:    %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
 // CHECK-NEXT:    %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
@@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
 // CHECK-DAG:     %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
 // CHECK-DAG:     %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
 // CHECK-DAG:     %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
-// CHECK:         %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK:         %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK-NEXT:    %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT:    %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
 // CHECK-NEXT:    %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
 // CHECK-NEXT:    return %[[FINAL_RESULT]] : vector<4xf32>
 func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
@@ -261,3 +251,15 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
     %ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
     return %ext : f32
 }
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> {
+    %splat = vector.broadcast %scale : f32 to vector<32xf32>
+    %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
+    return %ext : vector<32xf32>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 488e75cbb1843..6d6e1e28d2c2c 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -88,28 +88,20 @@ func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 0, 0]
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT:    %[[P1:.+]] = amdgpu.packed_scaled_trunc
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
-// CHECK-NEXT:    vector.shape_cast
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT:    %[[P2:.+]] = amdgpu.packed_scaled_trunc {{.*}} into %[[P1]][1]
+// CHECK-NEXT:    %[[P2_CAST:.+]] = vector.shape_cast %[[P2]] : vector<4xf8E5M2> to vector<1x1x4xf8E5M2>
+// CHECK-NEXT:    vector.insert_strided_slice %[[P2_CAST]], %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
 // CHECK-NEXT:    vector.shape_cast
 // CHECK-NEXT:    vector.extract %[[SCALE_EXT]][0, 1, 0]
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT:    vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
 // CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
 // CHECK-NEXT:    amdgpu.packed_scaled_trunc
-// CHECK-NEXT:    vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-N...
[truncated]

@krzysz00 krzysz00 requested a review from tgymnich July 24, 2025 16:34
@krzysz00 krzysz00 merged commit a4dd51d into llvm:main Jul 24, 2025
9 checks passed
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf
don't take full advantage of the native packing ability of the
intrinsics being targetted. Scaling extension takes the location of the
two elements to be extended as a constant argument (byte for fp4, half
for fp8), and scaling truncation takes a 32-bit input register and a
byte or half to write the truncated values to.

Not using these features would cause excess unneeded register pressure.
This PR resolves the inefficiency.

It also adds a test for the expected usecase of extending or
truncateting a block of 32 values to/from fp4 with a uniform scale to
ensure that this usage has a minimal amount of vector shuffling.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants