[mlir][ArithToAMDGPU] Use native packing support #150342
Conversation
The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf don't take full advantage of the native packing ability of the intrinsics being targeted. Scaling extension takes the location of the two elements to be extended as a constant argument (byte for fp4, half for fp8), and scaling truncation takes an existing 32-bit register along with the byte or half within it to write the truncated values to. Not using these features causes excess register pressure. This PR resolves the inefficiency. It also adds a test for the expected use case of extending or truncating a block of 32 values to/from fp4 with a uniform scale, to ensure that this usage requires a minimal amount of vector shuffling.
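
For illustration, a minimal MLIR sketch of the packed extension form this lowering now emits, mirroring the CHECK lines in the updated scaling-extf.mlir test; the function name and value names are placeholders, not part of the patch:

func.func @packed_ext_sketch(%in: vector<4xf8E5M2>, %scale: f32) -> (vector<2xf32>, vector<2xf32>) {
  // The [0]/[1] operand selects which packed pair of %in to extend, so the
  // source vector no longer has to be pre-sliced into 2-element chunks.
  %lo = amdgpu.scaled_ext_packed %in[0], %scale : vector<4xf8E5M2> to vector<2xf32>
  %hi = amdgpu.scaled_ext_packed %in[1], %scale : vector<4xf8E5M2> to vector<2xf32>
  return %lo, %hi : vector<2xf32>, vector<2xf32>
}

The truncation side is analogous: amdgpu.packed_scaled_trunc writes each pair of truncated values into a selected byte or half of an existing 32-bit packed register (the "into %prev[1]" clause visible in the updated scaling-truncf.mlir CHECK lines) rather than producing a fresh register per pair.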
@llvm/pr-subscribers-backend-amdgpu @llvm/pr-subscribers-mlir-gpu

Author: Krzysztof Drewniak (krzysz00)

Changes

The current arith-to-amdgpu patterns for scaling_extf and scaling_truncf don't take full advantage of the native packing ability of the intrinsics being targeted. Scaling extension takes the location of the two elements to be extended as a constant argument (byte for fp4, half for fp8), and scaling truncation takes an existing 32-bit register along with the byte or half within it to write the truncated values to. Not using these features causes excess register pressure. This PR resolves the inefficiency. It also adds a test for the expected use case of extending or truncating a block of 32 values to/from fp4 with a uniform scale, to ensure that this usage requires a minimal amount of vector shuffling.

Patch is 28.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/150342.diff

3 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 8c68b57877c35..4cf80167b20c2 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -449,7 +449,8 @@ LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opOutWidth = 2;
+ constexpr int64_t opInWidth = 8;
Value in = op.getIn();
Value scale = op.getScale();
@@ -473,7 +474,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
else if (scaleType.getIntOrFloatBitWidth() > 32)
scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
- VectorType extScaleResultType = VectorType::get(opWidth, outType);
+ VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
Value inCast = vector::BroadcastOp::create(rewriter, loc,
@@ -487,10 +488,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
+ if (origScaleVecType)
llvm::append_range(originalScaleShape, origScaleVecType.getShape());
originalScaleShape.insert(originalScaleShape.end(),
@@ -524,19 +526,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
Value blockResult =
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
+ for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = vector::ExtractStridedSliceOp::create(
- rewriter, loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
- Value scaleExt = amdgpu::ScaledExtPackedOp::create(
- rewriter, loc, extScaleResultType, slice, uniformScale, 0);
- if (sliceWidth != opWidth)
- scaleExt = vector::ExtractStridedSliceOp::create(
- rewriter, loc, scaleExt, 0, sliceWidth, 1);
- blockResult = vector::InsertStridedSliceOp::create(
- rewriter, loc, scaleExt, blockResult, i, 1);
+ i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
+ Value inSlice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, inSliceWidth, 1);
+ for (int64_t j = 0,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
+ j < inSliceWidth; j += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
+ // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
+ Value scaleExt = amdgpu::ScaledExtPackedOp::create(
+ rewriter, loc, extScaleResultType, inSlice, uniformScale,
+ j / opOutWidth);
+ if (outSliceWidth < opOutWidth) {
+ scaleExt = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, scaleExt, 0, outSliceWidth, 1);
+ }
+ blockResult = vector::InsertStridedSliceOp::create(
+ rewriter, loc, scaleExt, blockResult, i + j, 1);
+ }
}
VectorType resultType = VectorType::get(ratio, outType);
@@ -555,7 +564,7 @@ LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
- constexpr int64_t opWidth = 2;
+ constexpr int64_t opInWidth = 2;
Value in = op.getIn();
Value scale = op.getScale();
@@ -568,7 +577,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
-
if (outVecType && outVecType.isScalable())
return failure();
@@ -581,8 +589,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
Value zero = arith::ConstantOp::create(rewriter, loc, outType,
rewriter.getFloatAttr(outType, 0.0));
- unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
- VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);
+ int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
+ VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
@@ -598,16 +606,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());
+ VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());
ArrayRef<int64_t> inShape = inVecType.getShape();
- SmallVector<int64_t> originalScaleShape;
- if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
- llvm::append_range(originalScaleShape, origScaleVecType.getShape());
+ SmallVector<int64_t> scaleShape;
+ if (origScaleVecType)
+ llvm::append_range(scaleShape, origScaleVecType.getShape());
- originalScaleShape.insert(originalScaleShape.end(),
- inShape.size() - originalScaleShape.size(), 1);
+ scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
- auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
+ auto maybeRatio = computeShapeRatio(inShape, scaleShape);
assert(maybeRatio &&
"failed to derive block size from broadcast or splat operation");
@@ -633,20 +641,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
Value blockResult =
rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
- for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
- i < blockSize;
- i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
- Value slice = vector::ExtractStridedSliceOp::create(
- rewriter, loc, block1D, i, sliceWidth, 1);
- // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
- Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
- rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
- /*existing=*/nullptr);
- int64_t packedWidth =
- cast<VectorType>(scaleTrunc.getType()).getNumElements();
- if (packedWidth != opWidth)
+ for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
+ i < blockSize; i += outSliceWidth,
+ outSliceWidth = std::min(opOutWidth, blockSize - i)) {
+ Value scaleTrunc;
+ // Case where <= 2 elements are being truncated.
+ if (outSliceWidth <= opInWidth) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i, outSliceWidth, 1);
+ // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
+ /*existing=*/nullptr);
+ } else {
+ scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
+ truncScaleResultType, zero);
+ for (int64_t j = 0,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j);
+ j < outSliceWidth; j += opInWidth,
+ inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
+ Value slice = vector::ExtractStridedSliceOp::create(
+ rewriter, loc, block1D, i + j, inSliceWidth, 1);
+ scaleTrunc = amdgpu::PackedScaledTruncOp::create(
+ rewriter, loc, truncScaleResultType, slice, uniformScale,
+ j / opInWidth, scaleTrunc);
+ }
+ }
+ if (outSliceWidth != opOutWidth) {
scaleTrunc = vector::ExtractStridedSliceOp::create(
- rewriter, loc, scaleTrunc, 0, sliceWidth, 1);
+ rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
+ }
blockResult = vector::InsertStridedSliceOp::create(
rewriter, loc, scaleTrunc, blockResult, i, 1);
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
index b98045195f8cf..a837bdb8be4fa 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -163,27 +163,23 @@ func.func @conversion_f4_f16_fallback(%in: vector<2x2xf4E2M1FN>, %scale: vector<
// CHECK-DAG: %[[SCALE_CAST:.+]] = vector.shape_cast %[[BCAST]]
// CHECK-DAG: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_CAST]]
// CHECK-DAG: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 0, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
-// CHECK-NEXT: vector.shape_cast
+// CHECK-NEXT: %[[IN_SLICE_CAST:.+]] = vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0]
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.scaled_ext_packed
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.scaled_ext_packed
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
+// CHECK-NEXT: %[[LOWHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][0]
+// CHECK-NEXT: vector.insert_strided_slice %[[LOWHALF]], %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT: %[[HIGHHALF:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_CAST]][1]
+// CHECK-NEXT: vector.insert_strided_slice %[[HIGHHALF]], %{{.+}} {offsets = [2], strides = [1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
// CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0]
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.scaled_ext_packed
// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.scaled_ext_packed
// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
// CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0, 1, 0], strides = [1, 1, 1]}
func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8M0FNU>) -> vector<8x8xf32> {
%bc = vector.broadcast %scale : vector<8x2xf8E8M0FNU> to vector<4x8x2xf8E8M0FNU>
%cast1 = vector.shape_cast %in : vector<8x8xf8E5M2> to vector<8x2x4xf8E5M2>
@@ -203,21 +199,17 @@ func.func @conversion_broadcast(%in: vector<8x8xf8E5M2>, %scale: vector<8x2xf8E8
// CHECK-NEXT: %[[SCALE_EXT:.+]] = arith.extf %[[SCALE_FLAT]] : vector<6xf8E8M0FNU> to vector<6xf32>
// CHECK-NEXT: %[[IN_SLICE_0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[SCALE_SCALAR_0:.+]] = vector.extract %[[SCALE_EXT]][0] : f32 from vector<6xf32>
-// CHECK-NEXT: %[[IN_CHUNK_0A:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0A]][0], %[[SCALE_SCALAR_0]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_0A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][0], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PARTIAL_ACC_0:.+]] = vector.insert_strided_slice %[[PACKED_0A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT: %[[IN_CHUNK_0B:.+]] = vector.extract_strided_slice %[[IN_SLICE_0]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_0B]][0], %[[SCALE_SCALAR_0]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_0B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_0]][1], %[[SCALE_SCALAR_0]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PACKED_0B:.+]] = vector.extract_strided_slice %[[PACKED_0B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
// CHECK-NEXT: %[[OUT_SLICE_0:.+]] = vector.insert_strided_slice %[[PACKED_0B]], %[[PARTIAL_ACC_0]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
// CHECK-NEXT: %[[FINAL_ACC_A:.+]] = vector.insert_strided_slice %[[OUT_SLICE_0]], %[[CST_FINAL]] {offsets = [0], strides = [1]} : vector<3xf32> into vector<6xf32>
// CHECK-NEXT: %[[IN_SLICE_1:.+]] = vector.extract_strided_slice %arg0 {offsets = [3], sizes = [3], strides = [1]} : vector<6xf8E5M2> to vector<3xf8E5M2>
// CHECK-NEXT: %[[SCALE_SCALAR_1:.+]] = vector.extract %[[SCALE_EXT]][3] : f32 from vector<6xf32>
-// CHECK-NEXT: %[[IN_CHUNK_1A:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [0], sizes = [2], strides = [1]} : vector<3xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1A]][0], %[[SCALE_SCALAR_1]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_1A:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][0], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PARTIAL_ACC_1:.+]] = vector.insert_strided_slice %[[PACKED_1A]], %[[CST_PARTIAL]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<3xf32>
-// CHECK-NEXT: %[[IN_CHUNK_1B:.+]] = vector.extract_strided_slice %[[IN_SLICE_1]] {offsets = [2], sizes = [1], strides = [1]} : vector<3xf8E5M2> to vector<1xf8E5M2>
-// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK_1B]][0], %[[SCALE_SCALAR_1]] : vector<1xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[PACKED_1B_RAW:.+]] = amdgpu.scaled_ext_packed %[[IN_SLICE_1]][1], %[[SCALE_SCALAR_1]] : vector<3xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[PACKED_1B:.+]] = vector.extract_strided_slice %[[PACKED_1B_RAW]] {offsets = [0], sizes = [1], strides = [1]} : vector<2xf32> to vector<1xf32>
// CHECK-NEXT: %[[OUT_SLICE_1:.+]] = vector.insert_strided_slice %[[PACKED_1B]], %[[PARTIAL_ACC_1]] {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
// CHECK-NEXT: %[[RESULT:.+]] = vector.insert_strided_slice %[[OUT_SLICE_1]], %[[FINAL_ACC_A]] {offsets = [3], strides = [1]} : vector<3xf32> into vector<6xf32>
@@ -236,11 +228,9 @@ func.func @conversion_broadcast_odd(%in: vector<6xf8E5M2>, %scale: vector<2xf8E8
// CHECK-DAG: %[[SCALE_SPLAT:.+]] = vector.broadcast %arg1 : f8E8M0FNU to vector<4xf8E8M0FNU>
// CHECK-DAG: %[[SCALE_EXTF:.+]] = arith.extf %[[SCALE_SPLAT]] : vector<4xf8E8M0FNU> to vector<4xf32>
// CHECK-DAG: %[[SCALE_SCALAR:.+]] = vector.extract %[[SCALE_EXTF]][0] : f32 from vector<4xf32>
-// CHECK: %[[IN_CHUNK0:.+]] = vector.extract_strided_slice %arg0 {offsets = [0], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK0]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK: %[[OUT_CHUNK0:.+]] = amdgpu.scaled_ext_packed %arg0[0], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[ACCUM_A:.+]] = vector.insert_strided_slice %[[OUT_CHUNK0]], %[[CST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
-// CHECK-NEXT: %[[IN_CHUNK1:.+]] = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf8E5M2> to vector<2xf8E5M2>
-// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %[[IN_CHUNK1]][0], %[[SCALE_SCALAR]] : vector<2xf8E5M2> to vector<2xf32>
+// CHECK-NEXT: %[[OUT_CHUNK1:.+]] = amdgpu.scaled_ext_packed %arg0[1], %[[SCALE_SCALAR]] : vector<4xf8E5M2> to vector<2xf32>
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = vector.insert_strided_slice %[[OUT_CHUNK1]], %[[ACCUM_A]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
// CHECK-NEXT: return %[[FINAL_RESULT]] : vector<4xf32>
func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vector<4xf32> {
@@ -261,3 +251,15 @@ func.func @conversion_scalar(%in: f8E5M2, %scale: f8E8M0FNU) -> f32 {
%ext = arith.scaling_extf %in, %scale : f8E5M2, f8E8M0FNU to f32
return %ext : f32
}
+
+// -----
+
+// CHECK-LABEL: @long_fp4_broadcast
+// CHECK-COUNT-4: amdgpu.scaled_ext_packed %{{.+}}[3]
+// CHECK-NOT: amdgpu.scaled_ext_packed
+// CHECK: return
+func.func @long_fp4_broadcast(%in: vector<32xf4E2M1FN>, %scale: f32) -> vector<32xf32> {
+ %splat = vector.broadcast %scale : f32 to vector<32xf32>
+ %ext = arith.scaling_extf %in, %splat : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
+ return %ext : vector<32xf32>
+}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
index 488e75cbb1843..6d6e1e28d2c2c 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -88,28 +88,20 @@ func.func @conversion_f4_fallback(%in: vector<2x2xf32>, %scale: vector<2x2xf8E8M
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 0, 0]
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
+// CHECK-NEXT: %[[P1:.+]] = amdgpu.packed_scaled_trunc
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
-// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [2], strides = [1]}
-// CHECK-NEXT: vector.shape_cast
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
+// CHECK-NEXT: %[[P2:.+]] = amdgpu.packed_scaled_trunc {{.*}} into %[[P1]][1]
+// CHECK-NEXT: %[[P2_CAST:.+]] = vector.shape_cast %[[P2]] : vector<4xf8E5M2> to vector<1x1x4xf8E5M2>
+// CHECK-NEXT: vector.insert_strided_slice %[[P2_CAST]], %{{.+}} {offsets = [0, 0, 0], strides = [1, 1, 1]}
// CHECK-NEXT: vector.extract_strided_slice %[[IN_CAST]] {offsets = [0, 1, 0], sizes = [1, 1, 4], strides = [1, 1, 1]}
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract %[[SCALE_EXT]][0, 1, 0]
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-NEXT: vector.insert_strided_slice %{{.+}}, %{{.+}} {offsets = [0], strides = [1]}
// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [2], sizes = [2], strides = [1]}
// CHECK-NEXT: amdgpu.packed_scaled_trunc
-// CHECK-NEXT: vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [2], strides = [1]}
-// CHECK-N...
[truncated]