Skip to content

Commit e4a90e2

Browse files
committed
[mlir][NFC] update mlir/Dialect create APIs (14/n)
See #147168 for more info.
1 parent b7e332d commit e4a90e2

23 files changed

+563
-557
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,8 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
9999
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
100100
Type allBitsType = rewriter.getIntegerType(bitwidth);
101101
auto allBitsVecType = VectorType::get({1}, allBitsType);
102-
Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
103-
Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
102+
Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
103+
Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
104104
return scalar;
105105
}
106106

@@ -118,27 +118,27 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
118118

119119
SmallVector<NamedAttribute> loadAttrs;
120120
patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
121-
Value initialLoad =
122-
rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
121+
Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
122+
invariantArgs, loadAttrs);
123123
Block *currentBlock = rewriter.getInsertionBlock();
124124
Block *afterAtomic =
125125
rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
126126
Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
127127

128128
rewriter.setInsertionPointToEnd(currentBlock);
129-
rewriter.create<cf::BranchOp>(loc, loopBlock, initialLoad);
129+
cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad);
130130

131131
rewriter.setInsertionPointToEnd(loopBlock);
132132
Value prevLoad = loopBlock->getArgument(0);
133-
Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
133+
Value operated = ArithOp::create(rewriter, loc, data, prevLoad);
134134
dataType = operated.getType();
135135

136136
SmallVector<NamedAttribute> cmpswapAttrs;
137137
patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
138138
SmallVector<Value> cmpswapArgs = {operated, prevLoad};
139139
cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end());
140-
Value atomicRes = rewriter.create<RawBufferAtomicCmpswapOp>(
141-
loc, dataType, cmpswapArgs, cmpswapAttrs);
140+
Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType,
141+
cmpswapArgs, cmpswapAttrs);
142142

143143
// We care about exact bitwise equality here, so do some bitcasts.
144144
// These will fold away during lowering to the ROCDL dialect, where
@@ -150,14 +150,15 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
150150
if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
151151
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
152152
prevLoadForCompare =
153-
rewriter.create<arith::BitcastOp>(loc, equivInt, prevLoad);
153+
arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad);
154154
atomicResForCompare =
155-
rewriter.create<arith::BitcastOp>(loc, equivInt, atomicRes);
155+
arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes);
156156
}
157-
Value canLeave = rewriter.create<arith::CmpIOp>(
158-
loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare);
159-
rewriter.create<cf::CondBranchOp>(loc, canLeave, afterAtomic, ValueRange{},
160-
loopBlock, atomicRes);
157+
Value canLeave =
158+
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
159+
atomicResForCompare, prevLoadForCompare);
160+
cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic, ValueRange{},
161+
loopBlock, atomicRes);
161162
rewriter.eraseOp(atomicOp);
162163
return success();
163164
}

mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
5656
vector::MaskedLoadOp maskedOp,
5757
bool passthru) {
5858
VectorType vectorType = maskedOp.getVectorType();
59-
Value load = builder.create<vector::LoadOp>(
60-
loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
59+
Value load = vector::LoadOp::create(
60+
builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
6161
if (passthru)
62-
load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
63-
load, maskedOp.getPassThru());
62+
load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
63+
load, maskedOp.getPassThru());
6464
return load;
6565
}
6666

@@ -110,7 +110,7 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
110110
SmallVector<OpFoldResult> indices = maskedOp.getIndices();
111111

112112
auto stridedMetadata =
113-
rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
113+
memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
114114
SmallVector<OpFoldResult> strides =
115115
stridedMetadata.getConstifiedMixedStrides();
116116
SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
@@ -124,47 +124,47 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
124124

125125
// delta = bufferSize - linearizedOffset
126126
Value vectorSizeOffset =
127-
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
127+
arith::ConstantIndexOp::create(rewriter, loc, vectorSize);
128128
Value linearIndex =
129129
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
130130
Value totalSize = getValueOrCreateConstantIndexOp(
131131
rewriter, loc, linearizedInfo.linearizedSize);
132-
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
132+
Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);
133133

134134
// 1) check if delta < vectorSize
135-
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
136-
loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
135+
Value isOutofBounds = arith::CmpIOp::create(
136+
rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
137137

138138
// 2) check if (delta % elements_per_word != 0)
139-
Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
140-
loc, llvm::divideCeil(32, elementBitWidth));
141-
Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
142-
loc, arith::CmpIPredicate::ne,
143-
rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
144-
rewriter.create<arith::ConstantIndexOp>(loc, 0));
139+
Value elementsPerWord = arith::ConstantIndexOp::create(
140+
rewriter, loc, llvm::divideCeil(32, elementBitWidth));
141+
Value isNotWordAligned = arith::CmpIOp::create(
142+
rewriter, loc, arith::CmpIPredicate::ne,
143+
arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
144+
arith::ConstantIndexOp::create(rewriter, loc, 0));
145145

146146
// We take the fallback of maskedload default lowering only if it is both
147147
// out-of-bounds and not word aligned. The fallback ensures correct results
148148
// when loading at the boundary of the buffer since buffer load returns
149149
// inconsistent zeros for the whole word when boundary is crossed.
150150
Value ifCondition =
151-
rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
151+
arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);
152152

153153
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
154154
Operation *read = builder.clone(*maskedOp.getOperation());
155155
read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
156156
Value readResult = read->getResult(0);
157-
builder.create<scf::YieldOp>(loc, readResult);
157+
scf::YieldOp::create(builder, loc, readResult);
158158
};
159159

160160
auto elseBuilder = [&](OpBuilder &builder, Location loc) {
161161
Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
162162
/*passthru=*/true);
163-
rewriter.create<scf::YieldOp>(loc, res);
163+
scf::YieldOp::create(rewriter, loc, res);
164164
};
165165

166166
auto ifOp =
167-
rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
167+
scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);
168168

169169
rewriter.replaceOp(maskedOp, ifOp);
170170

@@ -187,13 +187,13 @@ struct FullMaskedLoadToConditionalLoad
187187
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
188188
Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
189189
/*passthru=*/false);
190-
rewriter.create<scf::YieldOp>(loc, res);
190+
scf::YieldOp::create(rewriter, loc, res);
191191
};
192192
auto falseBuilder = [&](OpBuilder &builder, Location loc) {
193-
rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
193+
scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
194194
};
195-
auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
196-
falseBuilder);
195+
auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
196+
falseBuilder);
197197
rewriter.replaceOp(loadOp, ifOp);
198198
return success();
199199
}
@@ -212,11 +212,12 @@ struct FullMaskedStoreToConditionalStore
212212
Value cond = maybeCond.value();
213213

214214
auto trueBuilder = [&](OpBuilder &builder, Location loc) {
215-
rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
216-
storeOp.getBase(), storeOp.getIndices());
217-
rewriter.create<scf::YieldOp>(loc);
215+
vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
216+
storeOp.getBase(), storeOp.getIndices());
217+
scf::YieldOp::create(rewriter, loc);
218218
};
219-
auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
219+
auto ifOp =
220+
scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
220221
rewriter.replaceOp(storeOp, ifOp);
221222
return success();
222223
}

mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final
3737
return rewriter.notifyMatchFailure(metadataOp,
3838
"not a fat raw buffer cast");
3939
Location loc = castOp.getLoc();
40-
auto sourceMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
41-
loc, castOp.getSource());
40+
auto sourceMetadata = memref::ExtractStridedMetadataOp::create(
41+
rewriter, loc, castOp.getSource());
4242
SmallVector<Value> results;
4343
if (metadataOp.getBaseBuffer().use_empty()) {
4444
results.push_back(nullptr);
@@ -48,13 +48,13 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final
4848
if (baseBufferType == castOp.getResult().getType()) {
4949
results.push_back(castOp.getResult());
5050
} else {
51-
results.push_back(rewriter.create<memref::ReinterpretCastOp>(
52-
loc, baseBufferType, castOp.getResult(), /*offset=*/0,
51+
results.push_back(memref::ReinterpretCastOp::create(
52+
rewriter, loc, baseBufferType, castOp.getResult(), /*offset=*/0,
5353
/*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{}));
5454
}
5555
}
5656
if (castOp.getResetOffset())
57-
results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
57+
results.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0));
5858
else
5959
results.push_back(sourceMetadata.getOffset());
6060
llvm::append_range(results, sourceMetadata.getSizes());

mlir/lib/Dialect/AMX/IR/AMXDialect.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
7676
auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
7777
auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
7878
return SmallVector<Value>{
79-
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
80-
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
79+
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
80+
LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
8181
}
8282

8383
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
@@ -95,15 +95,15 @@ static Value getStride(Location loc, MemRefType mType, Value base,
9595
// Dynamic stride needs code to compute the stride at runtime.
9696
MemRefDescriptor memrefDescriptor(base);
9797
auto attr = rewriter.getI64IntegerAttr(bytes);
98-
Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
98+
Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
9999
return rewriter
100100
.create<LLVM::MulOp>(loc, llvmInt64Type, scale,
101101
memrefDescriptor.stride(rewriter, loc, preLast))
102102
.getResult();
103103
}
104104
// Use direct constant for static stride.
105105
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
106-
return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
106+
return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
107107
.getResult();
108108
}
109109

mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
202202
void AffineDataCopyGeneration::runOnOperation() {
203203
func::FuncOp f = getOperation();
204204
OpBuilder topBuilder(f.getBody());
205-
zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
205+
zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0);
206206

207207
// Nests that are copy-in's or copy-out's; the root AffineForOps of those
208208
// nests are stored herein.

mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
5858
// Note: basis elements and their products are, definitionally,
5959
// non-negative, so `nuw` is justified.
6060
if (dynamicPart)
61-
dynamicPart = rewriter.create<arith::MulIOp>(
62-
loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
61+
dynamicPart =
62+
arith::MulIOp::create(rewriter, loc, dynamicPart,
63+
dynamicBasis[dynamicIndex - 1], ovflags);
6364
else
6465
dynamicPart = dynamicBasis[dynamicIndex - 1];
6566
--dynamicIndex;
@@ -74,7 +75,7 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
7475
rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
7576
if (dynamicPart)
7677
stride =
77-
rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
78+
arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
7879
result.push_back(stride);
7980
}
8081
}
@@ -106,20 +107,20 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
106107
Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
107108

108109
Value initialPart =
109-
rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
110+
arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
110111
results.push_back(initialPart);
111112

112113
auto emitModTerm = [&](Value stride) -> Value {
113-
Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
114-
Value remainderNegative = rewriter.create<arith::CmpIOp>(
115-
loc, arith::CmpIPredicate::slt, remainder, zero);
114+
Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
115+
Value remainderNegative = arith::CmpIOp::create(
116+
rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
116117
// If the correction is relevant, this term is <= stride, which is known
117118
// to be positive in `index`. Otherwise, while 2 * stride might overflow,
118119
// this branch won't be taken, so the risk of `poison` is fine.
119-
Value corrected = rewriter.create<arith::AddIOp>(
120-
loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
121-
Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
122-
corrected, remainder);
120+
Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
121+
arith::IntegerOverflowFlags::nsw);
122+
Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
123+
corrected, remainder);
123124
return mod;
124125
};
125126

@@ -131,7 +132,7 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
131132
// We know both inputs are positive, so floorDiv == div.
132133
// This could potentially be a divui, but it's not clear if that would
133134
// cause issues.
134-
Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
135+
Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
135136
results.push_back(divided);
136137
}
137138

@@ -167,8 +168,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
167168
// our hands on an `OpOperand&` for the loop invariant counting function.
168169
for (auto [stride, idxOp] :
169170
llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
170-
Value scaledIdx = rewriter.create<arith::MulIOp>(
171-
loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
171+
Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
172+
arith::IntegerOverflowFlags::nsw);
172173
int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
173174
scaledValues.emplace_back(scaledIdx, numHoistableLoops);
174175
}
@@ -184,8 +185,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
184185
Value result = scaledValues.front().first;
185186
for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
186187
std::ignore = numHoistableLoops;
187-
result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
188-
arith::IntegerOverflowFlags::nsw);
188+
result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
189+
arith::IntegerOverflowFlags::nsw);
189190
}
190191
rewriter.replaceOp(op, result);
191192
return success();

mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ static AffineApplyOp createSubApply(RewriterBase &rewriter,
8888
auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx);
8989
SmallVector<Value> rhsOperands = originalOp->getOperands();
9090
canonicalizeMapAndOperands(&rhsMap, &rhsOperands);
91-
return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap,
92-
rhsOperands);
91+
return AffineApplyOp::create(rewriter, originalOp.getLoc(), rhsMap,
92+
rhsOperands);
9393
}
9494

9595
FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
@@ -160,8 +160,8 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
160160
auto current = createSubApply(rewriter, op, subExpressions[0]);
161161
for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
162162
Value tmp = createSubApply(rewriter, op, subExpressions[i]);
163-
current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
164-
ValueRange{current, tmp});
163+
current = AffineApplyOp::create(rewriter, op.getLoc(), binMap,
164+
ValueRange{current, tmp});
165165
LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
166166
}
167167

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ static Value createPrivateMemRef(AffineForOp forOp,
424424
// consumer loop nests to reduce their live range. Currently they are added
425425
// at the beginning of the block, because loop nests can be reordered
426426
// during the fusion pass.
427-
Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
427+
Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
428428

429429
// Build an AffineMap to remap access functions based on lower bound offsets.
430430
SmallVector<AffineExpr, 4> remapExprs;

0 commit comments

Comments
 (0)