Reland "[mlir][vector] Use vector.broadcast in place of vector.splat" #150138

Merged
merged 1 commit into llvm:main on Jul 23, 2025

Conversation

newling
Contributor

@newling newling commented Jul 22, 2025

This reverts commit 228c45f (PR #148937). Now that #148027 is landed, I think it is safe to "reland" the original PR: #148028.
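For context, the change being relanded replaces uses of vector.splat with the equivalent vector.broadcast, both in the C++ builders and in test CHECK lines, as the diff below shows. A minimal MLIR sketch of the before/after forms; the function name and vector shape here are illustrative, not taken from the diff:

```mlir
func.func @splat_to_broadcast(%s: f32) -> (vector<2x1xf32>, vector<2x1xf32>) {
  // Old form: splat a scalar across every element of the result vector.
  %v0 = vector.splat %s : vector<2x1xf32>
  // New form: the equivalent vector.broadcast, which also spells out the source type.
  %v1 = vector.broadcast %s : f32 to vector<2x1xf32>
  return %v0, %v1 : vector<2x1xf32>, vector<2x1xf32>
}
```

A broadcast from a scalar produces exactly the splatted value; vector.broadcast also accepts vector sources, so it subsumes the scalar-only vector.splat form.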

@llvmbot
Member

llvmbot commented Jul 22, 2025

@llvm/pr-subscribers-mlir-nvgpu
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Changes

Revert "Revert "[mlir][vector] Use vector.broadcast in place of vector.splat (#148937)""

This reverts commit 228c45f.

Now that #148027 is landed, I think it is safe to "reland" the original PR: #148028


Full diff: https://github.com/llvm/llvm-project/pull/150138.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp (+1-1)
  • (modified) mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp (+1-1)
  • (modified) mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir (+3-3)
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index f14264e2f55f3..55b757c136127 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,7 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
                                vector::OuterProductOp, vector::ScanOp>(
       [&](Operation *op) { return converter.isLegal(op); });
   target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
-                    arith::ConstantOp, vector::SplatOp>();
+                    arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
 }
 
 void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 5d253c1199dc0..f5f0bfa4128aa 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -688,7 +688,7 @@ Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
 
   Type elementType = getElementTypeOrSelf(memref.getType());
   auto vt = VectorType::get(vectorShape, elementType);
-  Value res = vector::SplatOp::create(b, loc, vt, loads[0]);
+  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
   foreachIndividualVectorElement(
       res,
       /*applyFn=*/
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index de67098d397f4..0d44415595cb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -438,7 +438,7 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
           Value inc = arith::ConstantIndexOp::create(rewriter, loc,
                                                      i * blockedChunkSize);
           Value incVec =
-              vector::SplatOp::create(rewriter, loc, indiceType, inc);
+              vector::BroadcastOp::create(rewriter, loc, indiceType, inc);
           Value offsetIndice =
               arith::AddIOp::create(rewriter, loc, indice, incVec);
 
diff --git a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
index 07e03f3b8473d..bbe27fe1b99d9 100644
--- a/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
+++ b/mlir/test/Dialect/NVGPU/transform-matmul-to-nvvm.mlir
@@ -20,14 +20,14 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_7:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_8:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_9:.*]] = memref.load %[[VAL_0]][%[[VAL_7]], %[[VAL_8]]] : memref<16x4xf32>
-// CHECK:           %[[VAL_10:.*]] = vector.splat %[[VAL_6]] : vector<2x1xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.broadcast %[[VAL_6]] : f32 to vector<2x1xf32>
 // CHECK:           %[[VAL_11:.*]] = vector.insert %[[VAL_6]], %[[VAL_10]] [0, 0] : f32 into vector<2x1xf32>
 // CHECK:           %[[LHS:.*]] = vector.insert %[[VAL_9]], %[[VAL_11]] [1, 0] : f32 into vector<2x1xf32>
 //
 // CHECK:           %[[VAL_13:.*]] = affine.apply #[[$mod4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_14:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_1]][%[[VAL_13]], %[[VAL_14]]] : memref<4x8xf32>
-// CHECK:           %[[VAL_16:.*]] = vector.splat %[[VAL_15]] : vector<1x1xf32>
+// CHECK:           %[[VAL_16:.*]] = vector.broadcast %[[VAL_15]] : f32 to vector<1x1xf32>
 // CHECK:           %[[RHS:.*]] = vector.insert %[[VAL_15]], %[[VAL_16]] [0, 0] : f32 into vector<1x1xf32>
 //
 // CHECK:           %[[VAL_18:.*]] = affine.apply #[[$div4]]()[%[[TIDX]]]
@@ -42,7 +42,7 @@ func.func @matmul_16x8x4xf32_global(
 // CHECK:           %[[VAL_27:.*]] = affine.apply #[[$div4p8]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_28:.*]] = affine.apply #[[$map4]]()[%[[TIDX]]]
 // CHECK:           %[[VAL_29:.*]] = memref.load %[[VAL_2]][%[[VAL_27]], %[[VAL_28]]] : memref<16x8xf32>
-// CHECK:           %[[VAL_30:.*]] = vector.splat %[[VAL_20]] : vector<2x2xf32>
+// CHECK:           %[[VAL_30:.*]] = vector.broadcast %[[VAL_20]] : f32 to vector<2x2xf32>
 // CHECK:           %[[VAL_31:.*]] = vector.insert %[[VAL_20]], %[[VAL_30]] [0, 0] : f32 into vector<2x2xf32>
 // CHECK:           %[[VAL_32:.*]] = vector.insert %[[VAL_23]], %[[VAL_31]] [0, 1] : f32 into vector<2x2xf32>
 // CHECK:           %[[VAL_33:.*]] = vector.insert %[[VAL_26]], %[[VAL_32]] [1, 0] : f32 into vector<2x2xf32>

@kuhar kuhar changed the title from Revert "Revert [mlir][vector] Use vector.broadcast in place of vector… to Reland "[mlir][vector] Use vector.broadcast in place of vector spalt" on Jul 23, 2025
@newling
Contributor Author

newling commented Jul 23, 2025

Heads up @cathyzhyi, I'm going to land this now (it is the change behind the regression you reported to me earlier, but it should be safe this time)

@newling newling changed the title from Reland "[mlir][vector] Use vector.broadcast in place of vector spalt" to Reland "[mlir][vector] Use vector.broadcast in place of vector splat" on Jul 23, 2025
@newling newling changed the title from Reland "[mlir][vector] Use vector.broadcast in place of vector splat" to Reland "[mlir][vector] Use vector.broadcast in place of vector.splat" on Jul 23, 2025
@newling newling merged commit 6ed921f into llvm:main Jul 23, 2025
14 checks passed
@cathyzhyi
Contributor

Heads up @cathyzhyi, I'm going to land this now (it is the change behind the regression you reported to me earlier, but it should be safe this time)

Thank you so much!

mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request on Jul 28, 2025

Reland "[mlir][vector] Use vector.broadcast in place of vector.splat" (llvm#150138)

This reverts commit 228c45f (PR llvm#148937). Now that llvm#148027 is landed, I think it is safe to "reland" the original PR: llvm#148028