diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp index 37e0d2af55fe1..6d1f64e94df15 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -99,8 +99,8 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); Type allBitsType = rewriter.getIntegerType(bitwidth); auto allBitsVecType = VectorType::get({1}, allBitsType); - Value bitcast = rewriter.create(loc, allBitsVecType, val); - Value scalar = rewriter.create(loc, bitcast, 0); + Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val); + Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0); return scalar; } @@ -118,27 +118,27 @@ LogicalResult RawBufferAtomicByCasPattern::matchAndRewrite( SmallVector loadAttrs; patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop); - Value initialLoad = - rewriter.create(loc, dataType, invariantArgs, loadAttrs); + Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType, + invariantArgs, loadAttrs); Block *currentBlock = rewriter.getInsertionBlock(); Block *afterAtomic = rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint()); Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc}); rewriter.setInsertionPointToEnd(currentBlock); - rewriter.create(loc, loopBlock, initialLoad); + cf::BranchOp::create(rewriter, loc, loopBlock, initialLoad); rewriter.setInsertionPointToEnd(loopBlock); Value prevLoad = loopBlock->getArgument(0); - Value operated = rewriter.create(loc, data, prevLoad); + Value operated = ArithOp::create(rewriter, loc, data, prevLoad); dataType = operated.getType(); SmallVector cmpswapAttrs; patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate); SmallVector cmpswapArgs = {operated, prevLoad}; 
cmpswapArgs.append(invariantArgs.begin(), invariantArgs.end()); - Value atomicRes = rewriter.create( - loc, dataType, cmpswapArgs, cmpswapAttrs); + Value atomicRes = RawBufferAtomicCmpswapOp::create(rewriter, loc, dataType, + cmpswapArgs, cmpswapAttrs); // We care about exact bitwise equality here, so do some bitcasts. // These will fold away during lowering to the ROCDL dialect, where @@ -150,14 +150,15 @@ LogicalResult RawBufferAtomicByCasPattern::matchAndRewrite( if (auto floatDataTy = dyn_cast(dataType)) { Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth()); prevLoadForCompare = - rewriter.create(loc, equivInt, prevLoad); + arith::BitcastOp::create(rewriter, loc, equivInt, prevLoad); atomicResForCompare = - rewriter.create(loc, equivInt, atomicRes); + arith::BitcastOp::create(rewriter, loc, equivInt, atomicRes); } - Value canLeave = rewriter.create( - loc, arith::CmpIPredicate::eq, atomicResForCompare, prevLoadForCompare); - rewriter.create(loc, canLeave, afterAtomic, ValueRange{}, - loopBlock, atomicRes); + Value canLeave = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + atomicResForCompare, prevLoadForCompare); + cf::CondBranchOp::create(rewriter, loc, canLeave, afterAtomic, ValueRange{}, + loopBlock, atomicRes); rewriter.eraseOp(atomicOp); return success(); } diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp index af8634c692654..f15c63c166e0a 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp @@ -54,11 +54,11 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc, vector::MaskedLoadOp maskedOp, bool passthru) { VectorType vectorType = maskedOp.getVectorType(); - Value load = builder.create( - loc, vectorType, maskedOp.getBase(), maskedOp.getIndices()); + Value load = vector::LoadOp::create( + builder, loc, vectorType, maskedOp.getBase(), 
maskedOp.getIndices()); if (passthru) - load = builder.create(loc, vectorType, maskedOp.getMask(), - load, maskedOp.getPassThru()); + load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(), + load, maskedOp.getPassThru()); return load; } @@ -108,7 +108,7 @@ struct MaskedLoadLowering final : OpRewritePattern { SmallVector indices = maskedOp.getIndices(); auto stridedMetadata = - rewriter.create(loc, src); + memref::ExtractStridedMetadataOp::create(rewriter, loc, src); SmallVector strides = stridedMetadata.getConstifiedMixedStrides(); SmallVector sizes = stridedMetadata.getConstifiedMixedSizes(); @@ -122,47 +122,47 @@ struct MaskedLoadLowering final : OpRewritePattern { // delta = bufferSize - linearizedOffset Value vectorSizeOffset = - rewriter.create(loc, vectorSize); + arith::ConstantIndexOp::create(rewriter, loc, vectorSize); Value linearIndex = getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices); Value totalSize = getValueOrCreateConstantIndexOp( rewriter, loc, linearizedInfo.linearizedSize); - Value delta = rewriter.create(loc, totalSize, linearIndex); + Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex); // 1) check if delta < vectorSize - Value isOutofBounds = rewriter.create( - loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset); + Value isOutofBounds = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset); // 2) check if (detla % elements_per_word != 0) - Value elementsPerWord = rewriter.create( - loc, llvm::divideCeil(32, elementBitWidth)); - Value isNotWordAligned = rewriter.create( - loc, arith::CmpIPredicate::ne, - rewriter.create(loc, delta, elementsPerWord), - rewriter.create(loc, 0)); + Value elementsPerWord = arith::ConstantIndexOp::create( + rewriter, loc, llvm::divideCeil(32, elementBitWidth)); + Value isNotWordAligned = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ne, + arith::RemUIOp::create(rewriter, loc, delta, 
elementsPerWord), + arith::ConstantIndexOp::create(rewriter, loc, 0)); // We take the fallback of maskedload default lowering only it is both // out-of-bounds and not word aligned. The fallback ensures correct results // when loading at the boundary of the buffer since buffer load returns // inconsistent zeros for the whole word when boundary is crossed. Value ifCondition = - rewriter.create(loc, isOutofBounds, isNotWordAligned); + arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned); auto thenBuilder = [&](OpBuilder &builder, Location loc) { Operation *read = builder.clone(*maskedOp.getOperation()); read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr()); Value readResult = read->getResult(0); - builder.create(loc, readResult); + scf::YieldOp::create(builder, loc, readResult); }; auto elseBuilder = [&](OpBuilder &builder, Location loc) { Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp, /*passthru=*/true); - rewriter.create(loc, res); + scf::YieldOp::create(rewriter, loc, res); }; auto ifOp = - rewriter.create(loc, ifCondition, thenBuilder, elseBuilder); + scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder); rewriter.replaceOp(maskedOp, ifOp); @@ -185,13 +185,13 @@ struct FullMaskedLoadToConditionalLoad auto trueBuilder = [&](OpBuilder &builder, Location loc) { Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp, /*passthru=*/false); - rewriter.create(loc, res); + scf::YieldOp::create(rewriter, loc, res); }; auto falseBuilder = [&](OpBuilder &builder, Location loc) { - rewriter.create(loc, loadOp.getPassThru()); + scf::YieldOp::create(rewriter, loc, loadOp.getPassThru()); }; - auto ifOp = rewriter.create(loadOp.getLoc(), cond, trueBuilder, - falseBuilder); + auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder, + falseBuilder); rewriter.replaceOp(loadOp, ifOp); return success(); } @@ -210,11 +210,12 @@ struct FullMaskedStoreToConditionalStore Value cond = maybeCond.value(); 
auto trueBuilder = [&](OpBuilder &builder, Location loc) { - rewriter.create(loc, storeOp.getValueToStore(), - storeOp.getBase(), storeOp.getIndices()); - rewriter.create(loc); + vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(), + storeOp.getBase(), storeOp.getIndices()); + scf::YieldOp::create(rewriter, loc); }; - auto ifOp = rewriter.create(storeOp.getLoc(), cond, trueBuilder); + auto ifOp = + scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder); rewriter.replaceOp(storeOp, ifOp); return success(); } diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp index 195f59d625554..f8bab8289cbc6 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp @@ -37,8 +37,8 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final return rewriter.notifyMatchFailure(metadataOp, "not a fat raw buffer cast"); Location loc = castOp.getLoc(); - auto sourceMetadata = rewriter.create( - loc, castOp.getSource()); + auto sourceMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, castOp.getSource()); SmallVector results; if (metadataOp.getBaseBuffer().use_empty()) { results.push_back(nullptr); @@ -48,13 +48,13 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final if (baseBufferType == castOp.getResult().getType()) { results.push_back(castOp.getResult()); } else { - results.push_back(rewriter.create( - loc, baseBufferType, castOp.getResult(), /*offset=*/0, + results.push_back(memref::ReinterpretCastOp::create( + rewriter, loc, baseBufferType, castOp.getResult(), /*offset=*/0, /*sizes=*/ArrayRef{}, /*strides=*/ArrayRef{})); } } if (castOp.getResetOffset()) - results.push_back(rewriter.create(loc, 0)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0)); else results.push_back(sourceMetadata.getOffset()); llvm::append_range(results, 
sourceMetadata.getSizes()); diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 12b375b373fa9..748ff1edbfeb2 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -76,8 +76,8 @@ static SmallVector getTileSizes(Location loc, amx::TileType tType, auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0)); auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes); return SmallVector{ - rewriter.create(loc, llvmInt16Type, mattr), - rewriter.create(loc, llvmInt16Type, nattr)}; + LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr), + LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)}; } /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer @@ -95,7 +95,7 @@ static Value getStride(Location loc, MemRefType mType, Value base, // Dynamic stride needs code to compute the stride at runtime. MemRefDescriptor memrefDescriptor(base); auto attr = rewriter.getI64IntegerAttr(bytes); - Value scale = rewriter.create(loc, llvmInt64Type, attr); + Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr); return rewriter .create(loc, llvmInt64Type, scale, memrefDescriptor.stride(rewriter, loc, preLast)) @@ -103,7 +103,7 @@ static Value getStride(Location loc, MemRefType mType, Value base, } // Use direct constant for static stride. 
auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes); - return rewriter.create(loc, llvmInt64Type, attr) + return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr) .getResult(); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index f18cec5a14fae..df39544aeaa09 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -202,7 +202,7 @@ void AffineDataCopyGeneration::runOnBlock(Block *block, void AffineDataCopyGeneration::runOnOperation() { func::FuncOp f = getOperation(); OpBuilder topBuilder(f.getBody()); - zeroIndex = topBuilder.create(f.getLoc(), 0); + zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0); // Nests that are copy-in's or copy-out's; the root AffineForOps of those // nests are stored herein. diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp index 5430bdc4ff858..c0d174a04abf9 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -58,8 +58,9 @@ static SmallVector computeStrides(Location loc, RewriterBase &rewriter, // Note: basis elements and their products are, definitionally, // non-negative, so `nuw` is justified. 
if (dynamicPart) - dynamicPart = rewriter.create( - loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags); + dynamicPart = + arith::MulIOp::create(rewriter, loc, dynamicPart, + dynamicBasis[dynamicIndex - 1], ovflags); else dynamicPart = dynamicBasis[dynamicIndex - 1]; --dynamicIndex; @@ -74,7 +75,7 @@ static SmallVector computeStrides(Location loc, RewriterBase &rewriter, rewriter.createOrFold(loc, staticPart); if (dynamicPart) stride = - rewriter.create(loc, dynamicPart, stride, ovflags); + arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags); result.push_back(stride); } } @@ -106,20 +107,20 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, Value zero = rewriter.createOrFold(loc, 0); Value initialPart = - rewriter.create(loc, linearIdx, strides.front()); + arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front()); results.push_back(initialPart); auto emitModTerm = [&](Value stride) -> Value { - Value remainder = rewriter.create(loc, linearIdx, stride); - Value remainderNegative = rewriter.create( - loc, arith::CmpIPredicate::slt, remainder, zero); + Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride); + Value remainderNegative = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, remainder, zero); // If the correction is relevant, this term is <= stride, which is known // to be positive in `index`. Otherwise, while 2 * stride might overflow, // this branch won't be taken, so the risk of `poison` is fine. 
- Value corrected = rewriter.create( - loc, remainder, stride, arith::IntegerOverflowFlags::nsw); - Value mod = rewriter.create(loc, remainderNegative, - corrected, remainder); + Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride, + arith::IntegerOverflowFlags::nsw); + Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative, + corrected, remainder); return mod; }; @@ -131,7 +132,7 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, // We know both inputs are positive, so floorDiv == div. // This could potentially be a divui, but it's not clear if that would // cause issues. - Value divided = rewriter.create(loc, modulus, nextStride); + Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride); results.push_back(divided); } @@ -167,8 +168,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter, // our hands on an `OpOperand&` for the loop invariant counting function. for (auto [stride, idxOp] : llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) { - Value scaledIdx = rewriter.create( - loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw); + Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride, + arith::IntegerOverflowFlags::nsw); int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp); scaledValues.emplace_back(scaledIdx, numHoistableLoops); } @@ -184,8 +185,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter, Value result = scaledValues.front().first; for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) { std::ignore = numHoistableLoops; - result = rewriter.create(loc, result, scaledValue, - arith::IntegerOverflowFlags::nsw); + result = arith::AddIOp::create(rewriter, loc, result, scaledValue, + arith::IntegerOverflowFlags::nsw); } rewriter.replaceOp(op, result); return success(); diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp 
b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp index 4fd0cf9b3cd25..3c00b323473d2 100644 --- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp @@ -88,8 +88,8 @@ static AffineApplyOp createSubApply(RewriterBase &rewriter, auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx); SmallVector rhsOperands = originalOp->getOperands(); canonicalizeMapAndOperands(&rhsMap, &rhsOperands); - return rewriter.create(originalOp.getLoc(), rhsMap, - rhsOperands); + return AffineApplyOp::create(rewriter, originalOp.getLoc(), rhsMap, + rhsOperands); } FailureOr mlir::affine::decompose(RewriterBase &rewriter, @@ -160,8 +160,8 @@ FailureOr mlir::affine::decompose(RewriterBase &rewriter, auto current = createSubApply(rewriter, op, subExpressions[0]); for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) { Value tmp = createSubApply(rewriter, op, subExpressions[i]); - current = rewriter.create(op.getLoc(), binMap, - ValueRange{current, tmp}); + current = AffineApplyOp::create(rewriter, op.getLoc(), binMap, + ValueRange{current, tmp}); LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n"); } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 1d5a665bf6bb1..6c9adff7e9106 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -424,7 +424,7 @@ static Value createPrivateMemRef(AffineForOp forOp, // consumer loop nests to reduce their live range. Currently they are added // at the beginning of the block, because loop nests can be reordered // during the fusion pass. - Value newMemRef = top.create(forOp.getLoc(), newMemRefType); + Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType); // Build an AffineMap to remap access functions based on lower bound offsets. 
SmallVector remapExprs; diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index 05a352f39a93c..c942c0248fefd 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -100,16 +100,16 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { } // Create and place the alloc right before the 'affine.for' operation. - Value newMemRef = bOuter.create( - forOp.getLoc(), newMemRefType, allocOperands); + Value newMemRef = memref::AllocOp::create(bOuter, forOp.getLoc(), + newMemRefType, allocOperands); // Create 'iv mod 2' value to index the leading dimension. auto d0 = bInner.getAffineDimExpr(0); int64_t step = forOp.getStepAsInt(); auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2); - auto ivModTwoOp = bInner.create(forOp.getLoc(), modTwoMap, - forOp.getInductionVar()); + auto ivModTwoOp = AffineApplyOp::create(bInner, forOp.getLoc(), modTwoMap, + forOp.getInductionVar()); // replaceAllMemRefUsesWith will succeed unless the forOp body has // non-dereferencing uses of the memref (dealloc's are fine though). @@ -130,7 +130,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { } // Insert the dealloc op right after the for loop. bOuter.setInsertionPointAfter(forOp); - bOuter.create(forOp.getLoc(), newMemRef); + memref::DeallocOp::create(bOuter, forOp.getLoc(), newMemRef); return true; } diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp index 1a266b72d1f8d..9537d3e75c26a 100644 --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -51,10 +51,10 @@ OpFoldResult affine::materializeComputedBound( "expected dynamic dim"); if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. 
- operands.push_back(b.create(loc, value, *dim)); + operands.push_back(tensor::DimOp::create(b, loc, value, *dim)); } else if (isa(value.getType())) { // A memref dimension is used: generate a memref.dim. - operands.push_back(b.create(loc, value, *dim)); + operands.push_back(memref::DimOp::create(b, loc, value, *dim)); } else { llvm_unreachable("cannot generate DimOp for unsupported shaped type"); } @@ -76,7 +76,7 @@ OpFoldResult affine::materializeComputedBound( operands[expr.getPosition() + boundMap.getNumDims()]); // General case: build affine.apply op. return static_cast( - b.create(loc, boundMap, operands).getResult()); + affine::AffineApplyOp::create(b, loc, boundMap, operands).getResult()); } FailureOr mlir::affine::reifyShapedValueDimBound( diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index 7fae260767e0a..10da9070136c1 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -905,8 +905,8 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map, for (auto resultExpr : map.getResults()) { auto singleResMap = AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr); - auto afOp = state.builder.create(op->getLoc(), singleResMap, - mapOperands); + auto afOp = AffineApplyOp::create(state.builder, op->getLoc(), singleResMap, + mapOperands); results.push_back(afOp); } } @@ -961,7 +961,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp, auto vecForOp = cast(parentOp); state.builder.setInsertionPointToStart(vecForOp.getBody()); auto newConstOp = - state.builder.create(constOp.getLoc(), vecAttr); + arith::ConstantOp::create(state.builder, constOp.getLoc(), vecAttr); // Register vector replacement for future uses in the scope. 
state.registerOpVectorReplacement(constOp, newConstOp); @@ -986,8 +986,8 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp, } } - auto newApplyOp = state.builder.create( - applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands); + auto newApplyOp = AffineApplyOp::create( + state.builder, applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands); // Register the new affine.apply result. state.registerValueScalarReplacement(applyOp.getResult(), @@ -1010,7 +1010,7 @@ static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind, auto vecTy = getVectorType(scalarTy, state.strategy); auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); auto newConstOp = - state.builder.create(oldOperand.getLoc(), vecAttr); + arith::ConstantOp::create(state.builder, oldOperand.getLoc(), vecAttr); return newConstOp; } @@ -1062,11 +1062,11 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) { AffineMap ubMap = vecForOp.getUpperBoundMap(); Value ub; if (ubMap.getNumResults() == 1) - ub = state.builder.create(loc, vecForOp.getUpperBoundMap(), - vecForOp.getUpperBoundOperands()); + ub = AffineApplyOp::create(state.builder, loc, vecForOp.getUpperBoundMap(), + vecForOp.getUpperBoundOperands()); else - ub = state.builder.create(loc, vecForOp.getUpperBoundMap(), - vecForOp.getUpperBoundOperands()); + ub = AffineMinOp::create(state.builder, loc, vecForOp.getUpperBoundMap(), + vecForOp.getUpperBoundOperands()); // Then we compute the number of (original) iterations left in the loop. 
AffineExpr subExpr = state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1); @@ -1080,7 +1080,7 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) { Type maskTy = VectorType::get(state.strategy->vectorSizes, state.builder.getIntegerType(1)); Value mask = - state.builder.create(loc, maskTy, itersLeft); + vector::CreateMaskOp::create(state.builder, loc, maskTy, itersLeft); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n" << itersLeft << "\n" @@ -1123,8 +1123,8 @@ static Operation *vectorizeUniform(Value uniformVal, state.builder.setInsertionPointAfterValue(uniformScalarRepl); auto vectorTy = getVectorType(uniformVal.getType(), state.strategy); - auto bcastOp = state.builder.create(uniformVal.getLoc(), - vectorTy, uniformScalarRepl); + auto bcastOp = BroadcastOp::create(state.builder, uniformVal.getLoc(), + vectorTy, uniformScalarRepl); state.registerValueVectorReplacement(uniformVal, bcastOp); return bcastOp; } @@ -1256,8 +1256,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - auto transfer = state.builder.create( - loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, + auto transfer = vector::TransferReadOp::create( + state.builder, loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, /*padding=*/std::nullopt, permutationMap); // Register replacement for future uses in the scope. 
@@ -1303,9 +1303,9 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - auto transfer = state.builder.create( - storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices, - permutationMap); + auto transfer = vector::TransferWriteOp::create( + state.builder, storeOp.getLoc(), vectorValue, storeOp.getMemRef(), + indices, permutationMap); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer); // Register replacement for future uses in the scope. @@ -1387,10 +1387,10 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp, } } - auto vecForOp = state.builder.create( - forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), - forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep, - vecIterOperands, + auto vecForOp = AffineForOp::create( + state.builder, forOp.getLoc(), forOp.getLowerBoundOperands(), + forOp.getLowerBoundMap(), forOp.getUpperBoundOperands(), + forOp.getUpperBoundMap(), newStep, vecIterOperands, /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) { // Make sure we don't create a default terminator in the loop body as // the proper terminator will be added during vectorization. @@ -1512,8 +1512,8 @@ static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp, // IterOperands are neutral element vectors. 
Value neutralVal = cast(newParentOp).getInits()[i]; state.builder.setInsertionPoint(combinerOps.back()); - Value maskedReducedVal = state.builder.create( - reducedVal.getLoc(), mask, reducedVal, neutralVal); + Value maskedReducedVal = arith::SelectOp::create( + state.builder, reducedVal.getLoc(), mask, reducedVal, neutralVal); LLVM_DEBUG( dbgs() << "\n[early-vect]+++++ masking an input to a binary op that" "produces value for a yield Op: " @@ -1865,7 +1865,6 @@ verifyLoopNesting(const std::vector> &loops) { return success(); } - /// External utility to vectorize affine loops in 'loops' using the n-D /// vectorization factors in 'vectorSizes'. By default, each vectorization /// factor is applied inner-to-outer to the loops of each loop nest. @@ -1927,4 +1926,4 @@ LogicalResult mlir::affine::vectorizeAffineLoopNest( if (failed(verifyLoopNesting(loops))) return failure(); return vectorizeLoopNest(loops, strategy); -} +} diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 21f69ad2d4c25..2de057d1d0758 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -54,8 +54,8 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, OpBuilder b(forOp); auto lbMap = forOp.getLowerBoundMap(); - auto lb = b.create(forOp.getLoc(), lbMap, - forOp.getLowerBoundOperands()); + auto lb = AffineApplyOp::create(b, forOp.getLoc(), lbMap, + forOp.getLowerBoundOperands()); // For each upper bound expr, get the range.
// Eg: affine.for %i = lb to min (ub1, ub2), @@ -71,7 +71,7 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, auto bumpMap = AffineMap::get(tripCountMap.getNumDims(), tripCountMap.getNumSymbols(), bumpExprs[i]); bumpValues[i] = - b.create(forOp.getLoc(), bumpMap, tripCountOperands); + AffineApplyOp::create(b, forOp.getLoc(), bumpMap, tripCountOperands); } SmallVector newUbExprs(tripCountMap.getNumResults()); @@ -134,8 +134,8 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) { builder.setInsertionPointToStart(&func.getFunctionBody().front()); else builder.setInsertionPoint(forOp); - auto constOp = builder.create( - forOp.getLoc(), forOp.getConstantLowerBound()); + auto constOp = arith::ConstantIndexOp::create( + builder, forOp.getLoc(), forOp.getConstantLowerBound()); iv.replaceAllUsesWith(constOp); } else { auto lbOperands = forOp.getLowerBoundOperands(); @@ -146,7 +146,7 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) { iv.replaceAllUsesWith(lbOperands[0]); } else { auto affineApplyOp = - builder.create(forOp.getLoc(), lbMap, lbOperands); + AffineApplyOp::create(builder, forOp.getLoc(), lbMap, lbOperands); iv.replaceAllUsesWith(affineApplyOp); } } @@ -181,8 +181,8 @@ static AffineForOp generateShiftedLoop( assert(ubMap.getNumInputs() == ubOperands.size()); auto loopChunk = - b.create(srcForOp.getLoc(), lbOperands, lbMap, ubOperands, - ubMap, srcForOp.getStepAsInt()); + AffineForOp::create(b, srcForOp.getLoc(), lbOperands, lbMap, ubOperands, + ubMap, srcForOp.getStepAsInt()); auto loopChunkIV = loopChunk.getInductionVar(); auto srcIV = srcForOp.getInductionVar(); @@ -197,8 +197,8 @@ static AffineForOp generateShiftedLoop( // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. 
if (!srcIV.use_empty() && shift != 0) { - auto ivRemap = bodyBuilder.create( - srcForOp.getLoc(), + auto ivRemap = AffineApplyOp::create( + bodyBuilder, srcForOp.getLoc(), bodyBuilder.getSingleDimShiftAffineMap( -static_cast(srcForOp.getStepAsInt() * shift)), loopChunkIV); @@ -433,7 +433,7 @@ static void constructTiledLoopNest(MutableArrayRef origLoops, for (unsigned i = 0; i < width; i++) { OpBuilder b(topLoop); // Loop bounds will be set later. - AffineForOp pointLoop = b.create(loc, 0, 0); + AffineForOp pointLoop = AffineForOp::create(b, loc, 0, 0); pointLoop.getBody()->getOperations().splice( pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); @@ -447,7 +447,7 @@ static void constructTiledLoopNest(MutableArrayRef origLoops, for (unsigned i = width; i < 2 * width; i++) { OpBuilder b(topLoop); // Loop bounds will be set later. - AffineForOp tileSpaceLoop = b.create(loc, 0, 0); + AffineForOp tileSpaceLoop = AffineForOp::create(b, loc, 0, 0); tileSpaceLoop.getBody()->getOperations().splice( tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); @@ -1048,7 +1048,7 @@ LogicalResult mlir::affine::loopUnrollByFactor( // iv' = iv + i * step auto d0 = b.getAffineDimExpr(0); auto bumpMap = AffineMap::get(1, 0, d0 + i * step); - return b.create(forOp.getLoc(), bumpMap, iv); + return AffineApplyOp::create(b, forOp.getLoc(), bumpMap, iv); }, /*annotateFn=*/annotateFn, /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues); @@ -1212,7 +1212,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp, auto d0 = builder.getAffineDimExpr(0); auto bumpMap = AffineMap::get(1, 0, d0 + i * step); auto ivUnroll = - builder.create(forOp.getLoc(), bumpMap, forOpIV); + AffineApplyOp::create(builder, forOp.getLoc(), bumpMap, forOpIV); operandMaps[i - 1].map(forOpIV, ivUnroll); } // Clone the sub-block being unroll-jammed. 
@@ -1541,8 +1541,8 @@ stripmineSink(AffineForOp forOp, uint64_t factor, for (auto t : targets) { // Insert newForOp before the terminator of `t`. auto b = OpBuilder::atBlockTerminator(t.getBody()); - auto newForOp = b.create(t.getLoc(), lbOperands, lbMap, - ubOperands, ubMap, originalStep); + auto newForOp = AffineForOp::create(b, t.getLoc(), lbOperands, lbMap, + ubOperands, ubMap, originalStep); auto begin = t.getBody()->begin(); // Skip terminator and `newForOp` which is just before the terminator. auto nOps = t.getBody()->getOperations().size() - 2; @@ -1616,9 +1616,9 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef loops) { // 1. Store the upper bound of the outermost loop in a variable. Value prev; if (!llvm::hasSingleElement(origUbMap.getResults())) - prev = builder.create(loc, origUbMap, ubOperands); + prev = AffineMinOp::create(builder, loc, origUbMap, ubOperands); else - prev = builder.create(loc, origUbMap, ubOperands); + prev = AffineApplyOp::create(builder, loc, origUbMap, ubOperands); upperBoundSymbols.push_back(prev); // 2. Emit code computing the upper bound of the coalesced loop as product of @@ -1630,16 +1630,16 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef loops) { Value upperBound; // If upper bound map has more than one result, take their minimum. if (!llvm::hasSingleElement(origUbMap.getResults())) - upperBound = builder.create(loc, origUbMap, ubOperands); + upperBound = AffineMinOp::create(builder, loc, origUbMap, ubOperands); else - upperBound = builder.create(loc, origUbMap, ubOperands); + upperBound = AffineApplyOp::create(builder, loc, origUbMap, ubOperands); upperBoundSymbols.push_back(upperBound); SmallVector operands; operands.push_back(prev); operands.push_back(upperBound); // Maintain running product of loop upper bounds. 
- prev = builder.create( - loc, + prev = AffineApplyOp::create( + builder, loc, AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, builder.getAffineDimExpr(0) * @@ -1668,13 +1668,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef loops) { SmallVector operands; operands.push_back(previous); operands.push_back(upperBoundSymbols[idx]); - previous = builder.create( - loc, - AffineMap::get( - /*dimCount=*/1, /*symbolCount=*/1, - builder.getAffineDimExpr(0).floorDiv( - builder.getAffineSymbolExpr(0))), - operands); + previous = AffineApplyOp::create(builder, loc, + AffineMap::get( + /*dimCount=*/1, /*symbolCount=*/1, + builder.getAffineDimExpr(0).floorDiv( + builder.getAffineSymbolExpr(0))), + operands); } // Modified value of the induction variables of the nested loops after // coalescing. @@ -1685,8 +1684,8 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef loops) { SmallVector applyOperands; applyOperands.push_back(previous); applyOperands.push_back(upperBoundSymbols[idx - 1]); - inductionVariable = builder.create( - loc, + inductionVariable = AffineApplyOp::create( + builder, loc, AffineMap::get( /*dimCount=*/1, /*symbolCount=*/1, builder.getAffineDimExpr(0) % builder.getAffineSymbolExpr(0)), @@ -1723,21 +1722,21 @@ void mlir::affine::mapLoopToProcessorIds(scf::ForOp forOp, Value linearIndex = processorId.front(); for (unsigned i = 1, e = processorId.size(); i < e; ++i) { - auto mulApplyOp = b.create( - loc, mulMap, ValueRange{linearIndex, numProcessors[i]}); - linearIndex = b.create( - loc, addMap, ValueRange{mulApplyOp, processorId[i]}); + auto mulApplyOp = AffineApplyOp::create( + b, loc, mulMap, ValueRange{linearIndex, numProcessors[i]}); + linearIndex = AffineApplyOp::create(b, loc, addMap, + ValueRange{mulApplyOp, processorId[i]}); } - auto mulApplyOp = b.create( - loc, mulMap, ValueRange{linearIndex, forOp.getStep()}); - Value lb = b.create( - loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()}); + auto mulApplyOp = 
AffineApplyOp::create( + b, loc, mulMap, ValueRange{linearIndex, forOp.getStep()}); + Value lb = AffineApplyOp::create( + b, loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()}); forOp.setLowerBound(lb); Value step = forOp.getStep(); for (auto numProcs : numProcessors) - step = b.create(loc, mulMap, ValueRange{numProcs, step}); + step = AffineApplyOp::create(b, loc, mulMap, ValueRange{numProcs, step}); forOp.setStep(step); } @@ -1874,7 +1873,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, auto fastBufOffsetMap = AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]); - auto offset = b.create(loc, fastBufOffsetMap, lbOperands); + auto offset = AffineApplyOp::create(b, loc, fastBufOffsetMap, lbOperands); // Construct the subscript for the fast memref being copied into/from: // x - offset_x. @@ -1901,16 +1900,16 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, if (!isCopyOut) { // Copy in. - auto load = b.create(loc, memref, memIndices); - b.create(loc, load, fastMemRef, fastBufMap, - fastBufMapOperands); + auto load = AffineLoadOp::create(b, loc, memref, memIndices); + AffineStoreOp::create(b, loc, load, fastMemRef, fastBufMap, + fastBufMapOperands); return copyNestRoot; } // Copy out. 
auto load = - b.create(loc, fastMemRef, fastBufMap, fastBufMapOperands); - b.create(loc, load, memref, memIndices); + AffineLoadOp::create(b, loc, fastMemRef, fastBufMap, fastBufMapOperands); + AffineStoreOp::create(b, loc, load, memref, memIndices); return copyNestRoot; } @@ -1945,7 +1944,7 @@ static LogicalResult generateCopy( auto f = begin->getParentOfType(); OpBuilder topBuilder(f.getFunctionBody()); - Value zeroIndex = topBuilder.create(f.getLoc(), 0); + Value zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0); *sizeInBytes = 0; @@ -2056,7 +2055,7 @@ static LogicalResult generateCopy( memIndices.push_back(zeroIndex); } else { memIndices.push_back( - top.create(loc, indexVal).getResult()); + arith::ConstantIndexOp::create(top, loc, indexVal).getResult()); } } else { // The coordinate for the start location is just the lower bound along the @@ -2070,7 +2069,8 @@ static LogicalResult generateCopy( lbs[d] = lbs[d].replaceDimsAndSymbols( /*dimReplacements=*/{}, symReplacements, lbs[d].getNumSymbols(), /*numResultSyms=*/0); - memIndices.push_back(b.create(loc, lbs[d], regionSymbols)); + memIndices.push_back( + AffineApplyOp::create(b, loc, lbs[d], regionSymbols)); } // The fast buffer is copied into at location zero; addressing is relative. bufIndices.push_back(zeroIndex); @@ -2094,7 +2094,7 @@ static LogicalResult generateCopy( // Create the fast memory space buffer just before the 'affine.for' // operation. fastMemRef = - prologue.create(loc, fastMemRefType).getResult(); + memref::AllocOp::create(prologue, loc, fastMemRefType).getResult(); // Record it. fastBufferMap[memref] = fastMemRef; // fastMemRefType is a constant shaped memref. 
@@ -2111,7 +2111,7 @@ static LogicalResult generateCopy( fastMemRef = fastBufferMap[memref]; } - auto numElementsSSA = top.create(loc, *numElements); + auto numElementsSSA = arith::ConstantIndexOp::create(top, loc, *numElements); Value dmaStride; Value numEltPerDmaStride; @@ -2128,9 +2128,9 @@ static LogicalResult generateCopy( if (!dmaStrideInfos.empty()) { dmaStride = - top.create(loc, dmaStrideInfos[0].stride); - numEltPerDmaStride = top.create( - loc, dmaStrideInfos[0].numEltPerStride); + arith::ConstantIndexOp::create(top, loc, dmaStrideInfos[0].stride); + numEltPerDmaStride = arith::ConstantIndexOp::create( + top, loc, dmaStrideInfos[0].numEltPerStride); } } @@ -2160,21 +2160,21 @@ static LogicalResult generateCopy( // Create a tag (single element 1-d memref) for the DMA. auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {}, copyOptions.tagMemorySpace); - auto tagMemRef = prologue.create(loc, tagMemRefType); + auto tagMemRef = memref::AllocOp::create(prologue, loc, tagMemRefType); SmallVector tagIndices({zeroIndex}); auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size()); fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices); if (!region.isWrite()) { // DMA non-blocking read from original buffer to fast buffer. - b.create(loc, memref, memAffineMap, memIndices, - fastMemRef, bufAffineMap, bufIndices, - tagMemRef, tagAffineMap, tagIndices, - numElementsSSA, dmaStride, numEltPerDmaStride); + AffineDmaStartOp::create(b, loc, memref, memAffineMap, memIndices, + fastMemRef, bufAffineMap, bufIndices, tagMemRef, + tagAffineMap, tagIndices, numElementsSSA, + dmaStride, numEltPerDmaStride); } else { // DMA non-blocking write from fast buffer to the original memref. 
- auto op = b.create( - loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap, + auto op = AffineDmaStartOp::create( + b, loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap, memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA, dmaStride, numEltPerDmaStride); // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the @@ -2184,11 +2184,11 @@ static LogicalResult generateCopy( } // Matching DMA wait to block on completion; tag always has a 0 index. - b.create(loc, tagMemRef, tagAffineMap, zeroIndex, - numElementsSSA); + AffineDmaWaitOp::create(b, loc, tagMemRef, tagAffineMap, zeroIndex, + numElementsSSA); // Generate dealloc for the tag. - auto tagDeallocOp = epilogue.create(loc, tagMemRef); + auto tagDeallocOp = memref::DeallocOp::create(epilogue, loc, tagMemRef); if (*nEnd == end && isCopyOutAtEndOfBlock) // Since new ops are being appended (for outgoing DMAs), adjust the end to // mark end of range of the original. @@ -2197,7 +2197,7 @@ static LogicalResult generateCopy( // Generate dealloc for the buffer. if (!existingBuf) { - auto bufDeallocOp = epilogue.create(loc, fastMemRef); + auto bufDeallocOp = memref::DeallocOp::create(epilogue, loc, fastMemRef); // When generating pointwise copies, `nEnd' has to be set to deallocOp on // the fast buffer (since it marks the new end insertion point). 
if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock) @@ -2567,8 +2567,8 @@ AffineForOp mlir::affine::createCanonicalizedAffineForOp( canonicalizeMapAndOperands(&ubMap, &upperOperands); ubMap = removeDuplicateExprs(ubMap); - return b.create(loc, lowerOperands, lbMap, upperOperands, ubMap, - step); + return AffineForOp::create(b, loc, lowerOperands, lbMap, upperOperands, ubMap, + step); } /// Creates an AffineIfOp that encodes the conditional to choose between @@ -2651,8 +2651,8 @@ static AffineIfOp createSeparationCondition(MutableArrayRef loops, SmallVector setOperands; cst.getValues(0, cst.getNumDimAndSymbolVars(), &setOperands); canonicalizeSetAndOperands(&ifCondSet, &setOperands); - return b.create(loops[0].getLoc(), ifCondSet, setOperands, - /*withElseRegion=*/true); + return AffineIfOp::create(b, loops[0].getLoc(), ifCondSet, setOperands, + /*withElseRegion=*/true); } /// Create the full tile loop nest (along with its body). diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 7bb158eb6dfc0..845be20d15b69 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -56,7 +56,7 @@ class AffineApplyExpander auto rhs = visit(expr.getRHS()); if (!lhs || !rhs) return nullptr; - auto op = builder.create(loc, lhs, rhs, overflowFlags); + auto op = OpTy::create(builder, loc, lhs, rhs, overflowFlags); return op.getResult(); } @@ -90,14 +90,14 @@ class AffineApplyExpander auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value remainder = builder.create(loc, lhs, rhs); - Value zeroCst = builder.create(loc, 0); - Value isRemainderNegative = builder.create( - loc, arith::CmpIPredicate::slt, remainder, zeroCst); + Value remainder = arith::RemSIOp::create(builder, loc, lhs, rhs); + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value isRemainderNegative = arith::CmpIOp::create( + builder, loc, 
arith::CmpIPredicate::slt, remainder, zeroCst); Value correctedRemainder = - builder.create(loc, remainder, rhs); - Value result = builder.create( - loc, isRemainderNegative, correctedRemainder, remainder); + arith::AddIOp::create(builder, loc, remainder, rhs); + Value result = arith::SelectOp::create(builder, loc, isRemainderNegative, + correctedRemainder, remainder); return result; } @@ -129,18 +129,19 @@ class AffineApplyExpander auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value noneCst = builder.create(loc, -1); - Value negative = builder.create( - loc, arith::CmpIPredicate::slt, lhs, zeroCst); - Value negatedDecremented = builder.create(loc, noneCst, lhs); - Value dividend = - builder.create(loc, negative, negatedDecremented, lhs); - Value quotient = builder.create(loc, dividend, rhs); + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value noneCst = arith::ConstantIndexOp::create(builder, loc, -1); + Value negative = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::slt, lhs, zeroCst); + Value negatedDecremented = + arith::SubIOp::create(builder, loc, noneCst, lhs); + Value dividend = arith::SelectOp::create(builder, loc, negative, + negatedDecremented, lhs); + Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs); Value correctedQuotient = - builder.create(loc, noneCst, quotient); - Value result = builder.create(loc, negative, - correctedQuotient, quotient); + arith::SubIOp::create(builder, loc, noneCst, quotient); + Value result = arith::SelectOp::create(builder, loc, negative, + correctedQuotient, quotient); return result; } @@ -168,26 +169,26 @@ class AffineApplyExpander auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create(loc, 0); - Value oneCst = builder.create(loc, 1); - Value nonPositive = builder.create( - loc, 
arith::CmpIPredicate::sle, lhs, zeroCst); - Value negated = builder.create(loc, zeroCst, lhs); - Value decremented = builder.create(loc, lhs, oneCst); - Value dividend = - builder.create(loc, nonPositive, negated, decremented); - Value quotient = builder.create(loc, dividend, rhs); + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value oneCst = arith::ConstantIndexOp::create(builder, loc, 1); + Value nonPositive = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::sle, lhs, zeroCst); + Value negated = arith::SubIOp::create(builder, loc, zeroCst, lhs); + Value decremented = arith::SubIOp::create(builder, loc, lhs, oneCst); + Value dividend = arith::SelectOp::create(builder, loc, nonPositive, negated, + decremented); + Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs); Value negatedQuotient = - builder.create(loc, zeroCst, quotient); + arith::SubIOp::create(builder, loc, zeroCst, quotient); Value incrementedQuotient = - builder.create(loc, quotient, oneCst); - Value result = builder.create( - loc, nonPositive, negatedQuotient, incrementedQuotient); + arith::AddIOp::create(builder, loc, quotient, oneCst); + Value result = arith::SelectOp::create( + builder, loc, nonPositive, negatedQuotient, incrementedQuotient); return result; } Value visitConstantExpr(AffineConstantExpr expr) { - auto op = builder.create(loc, expr.getValue()); + auto op = arith::ConstantIndexOp::create(builder, loc, expr.getValue()); return op.getResult(); } @@ -297,9 +298,9 @@ static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) { // block. IRMapping operandMap; OpBuilder b(hoistOverOp); - auto hoistedIfOp = b.create(ifOp.getLoc(), ifOp.getIntegerSet(), - ifOp.getOperands(), - /*elseBlock=*/true); + auto hoistedIfOp = AffineIfOp::create(b, ifOp.getLoc(), ifOp.getIntegerSet(), + ifOp.getOperands(), + /*elseBlock=*/true); // Create a clone of hoistOverOp to use for the else branch of the hoisted // conditional. 
The else block may get optimized away if empty. @@ -368,8 +369,8 @@ mlir::affine::affineParallelize(AffineForOp forOp, parallelReductions, [](const LoopReduction &red) { return red.value; })); auto reductionKinds = llvm::to_vector<4>(llvm::map_range( parallelReductions, [](const LoopReduction &red) { return red.kind; })); - AffineParallelOp newPloop = outsideBuilder.create( - loc, ValueRange(reducedValues).getTypes(), reductionKinds, + AffineParallelOp newPloop = AffineParallelOp::create( + outsideBuilder, loc, ValueRange(reducedValues).getTypes(), reductionKinds, llvm::ArrayRef(lowerBoundMap), lowerBoundOperands, llvm::ArrayRef(upperBoundMap), upperBoundOperands, llvm::ArrayRef(forOp.getStepAsInt())); @@ -540,7 +541,8 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) { SmallVector applyOperands{dimOperands}; applyOperands.push_back(iv); applyOperands.append(symbolOperands.begin(), symbolOperands.end()); - auto apply = builder.create(op.getLoc(), map, applyOperands); + auto apply = + AffineApplyOp::create(builder, op.getLoc(), map, applyOperands); iv.replaceAllUsesExcept(apply, apply); } @@ -621,8 +623,9 @@ LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op, AffineValueMap newIvToOldIvMap; AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap); (void)newIvToOldIvMap.canonicalize(); - auto newIV = opBuilder.create( - loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands()); + auto newIV = + AffineApplyOp::create(opBuilder, loc, newIvToOldIvMap.getAffineMap(), + newIvToOldIvMap.getOperands()); op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV); return success(); } @@ -1186,8 +1189,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( for (auto resultExpr : oldMap.getResults()) { auto singleResMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); - auto afOp = builder.create(op->getLoc(), singleResMap, - oldMapOperands); + auto afOp = AffineApplyOp::create(builder, 
op->getLoc(), singleResMap, + oldMapOperands); oldMemRefOperands.push_back(afOp); affineApplyOps.push_back(afOp); } @@ -1213,8 +1216,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( for (auto resultExpr : indexRemap.getResults()) { auto singleResMap = AffineMap::get( indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); - auto afOp = builder.create(op->getLoc(), singleResMap, - remapOperands); + auto afOp = AffineApplyOp::create(builder, op->getLoc(), singleResMap, + remapOperands); remapOutputs.push_back(afOp); affineApplyOps.push_back(afOp); } @@ -1263,8 +1266,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( // AffineMapAccessInterface, we need to apply the values of `newMapOperands` // to the `newMap` to get the correct indices. for (unsigned i = 0; i < newMemRefRank; i++) { - state.operands.push_back(builder.create( - op->getLoc(), + state.operands.push_back(AffineApplyOp::create( + builder, op->getLoc(), AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(), newMap.getResult(i)), newMapOperands)); @@ -1449,8 +1452,8 @@ void mlir::affine::createAffineComputationSlice( for (auto resultExpr : composedMap.getResults()) { auto singleResMap = AffineMap::get(composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr); - sliceOps->push_back(builder.create( - opInst->getLoc(), singleResMap, composedOpOperands)); + sliceOps->push_back(AffineApplyOp::create( + builder, opInst->getLoc(), singleResMap, composedOpOperands)); } // Construct the new operands that include the results from the composed @@ -1680,7 +1683,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType, // Create ConstantOp for static dimension. 
auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]); inAffineApply.emplace_back( - b.create(allocOp.getLoc(), constantAttr)); + arith::ConstantOp::create(b, allocOp.getLoc(), constantAttr)); } } @@ -1704,7 +1707,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType, AffineMap newMap = AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput); Value affineApp = - b.create(allocOp.getLoc(), newMap, inAffineApply); + AffineApplyOp::create(b, allocOp.getLoc(), newMap, inAffineApply); newDynamicSizes.emplace_back(affineApp); } newDimIdx++; @@ -1739,12 +1742,11 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) { createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b, newDynamicSizes); // Add the new dynamic sizes in new AllocOp. - newAlloc = - b.create(allocOp.getLoc(), newMemRefType, newDynamicSizes, - allocOp.getAlignmentAttr()); + newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType, + newDynamicSizes, allocOp.getAlignmentAttr()); } else { - newAlloc = b.create(allocOp.getLoc(), newMemRefType, - allocOp.getAlignmentAttr()); + newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType, + allocOp.getAlignmentAttr()); } // Replace all uses of the old memref. 
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, @@ -1802,10 +1804,10 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) { for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) { if (memrefType.isDynamicDim(i)) mapOperands[i] = - b.create(loc, oldSizes[0].getType(), oldSizes[idx++], - b.create(loc, 1)); + arith::SubIOp::create(b, loc, oldSizes[0].getType(), oldSizes[idx++], + arith::ConstantIndexOp::create(b, loc, 1)); else - mapOperands[i] = b.create(loc, oldShape[i] - 1); + mapOperands[i] = arith::ConstantIndexOp::create(b, loc, oldShape[i] - 1); } for (unsigned i = 0, e = oldStrides.size(); i < e; i++) mapOperands[memrefType.getRank() + i] = oldStrides[i]; @@ -1815,20 +1817,20 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) { for (unsigned i = 0; i < newRank; i++) { if (!newMemRefType.isDynamicDim(i)) continue; - newSizes.push_back(b.create( - loc, + newSizes.push_back(AffineApplyOp::create( + b, loc, AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(), oldLayoutMap.getResult(i)), mapOperands)); } for (unsigned i = 0, e = newSizes.size(); i < e; i++) { newSizes[i] = - b.create(loc, newSizes[i].getType(), newSizes[i], - b.create(loc, 1)); + arith::AddIOp::create(b, loc, newSizes[i].getType(), newSizes[i], + arith::ConstantIndexOp::create(b, loc, 1)); } // Create the new reinterpret_cast op. 
- auto newReinterpretCast = b.create( - loc, newMemRefType, reinterpretCastOp.getSource(), + auto newReinterpretCast = memref::ReinterpretCastOp::create( + b, loc, newMemRefType, reinterpretCastOp.getSource(), /*offsets=*/ValueRange(), newSizes, /*strides=*/ValueRange(), /*static_offsets=*/newStaticOffsets, diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp index ebcb951cf3518..e7cbee6b06c45 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -64,7 +64,7 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp index f2e7732e8ea4a..9199dccdcaff3 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp @@ -67,8 +67,8 @@ struct SelectOpInterface return state.getMemrefWithUniqueOwnership(builder, value, value.getParentBlock()); - Value ownership = builder.create( - op->getLoc(), selectOp.getCondition(), + Value ownership = arith::SelectOp::create( + builder, op->getLoc(), selectOp.getCondition(), state.getOwnership(selectOp.getTrueValue(), block).getIndicator(), state.getOwnership(selectOp.getFalseValue(), block).getIndicator()); return {selectOp.getResult(), ownership}; diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index afee162053bea..b073a31850678 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ 
b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -170,10 +170,10 @@ struct SelectOpInterface return failure(); if (trueBuffer.getType() != *targetType) trueBuffer = - rewriter.create(loc, *targetType, trueBuffer); + memref::CastOp::create(rewriter, loc, *targetType, trueBuffer); if (falseBuffer.getType() != *targetType) falseBuffer = - rewriter.create(loc, *targetType, falseBuffer); + memref::CastOp::create(rewriter, loc, *targetType, falseBuffer); } replaceOpWithNewBufferizedOp( diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 55b757c136127..7626d356a37f2 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -75,7 +75,7 @@ LogicalResult EmulateFloatPattern::matchAndRewrite( for (auto [res, oldType, newType] : llvm::zip_equal( MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { if (oldType != newType) { - auto truncFOp = rewriter.create(loc, oldType, res); + auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res); truncFOp.setFastmath(arith::FastMathFlags::contract); res = truncFOp.getResult(); } @@ -98,7 +98,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions( }); converter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - auto extFOp = b.create(loc, target, input); + auto extFOp = arith::ExtFOp::create(b, loc, target, input); extFOp.setFastmath(arith::FastMathFlags::contract); return extFOp; }); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp index d5d1559c658ff..efe6ad2579055 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -72,7 +72,7 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, // Scalarize 
the result in case of 1D vectors. if (shape.size() == 1) - return rewriter.create(loc, input, lastOffset); + return vector::ExtractOp::create(rewriter, loc, input, lastOffset); SmallVector offsets(shape.size(), 0); offsets.back() = lastOffset; @@ -80,8 +80,8 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, sizes.back() = 1; SmallVector strides(shape.size(), 1); - return rewriter.create(loc, input, offsets, - sizes, strides); + return vector::ExtractStridedSliceOp::create(rewriter, loc, input, offsets, + sizes, strides); } /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`, @@ -107,7 +107,7 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, assert(shape.back() == 1 && "Expected the last vector dim to be x1"); auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType()); - return rewriter.create(loc, newVecTy, input); + return vector::ShapeCastOp::create(rewriter, loc, newVecTy, input); } /// Performs a vector shape cast to append an x1 dimension. If the @@ -122,7 +122,7 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, auto newShape = llvm::to_vector(vecTy.getShape()); newShape.push_back(1); auto newTy = VectorType::get(newShape, vecTy.getElementType()); - return rewriter.create(loc, newTy, input); + return vector::ShapeCastOp::create(rewriter, loc, newTy, input); } /// Inserts the `source` vector slice into the `dest` vector at offset @@ -136,13 +136,13 @@ static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, // Handle scalar source. 
if (isa(source.getType())) - return rewriter.create(loc, source, dest, lastOffset); + return vector::InsertOp::create(rewriter, loc, source, dest, lastOffset); SmallVector offsets(shape.size(), 0); offsets.back() = lastOffset; SmallVector strides(shape.size(), 1); - return rewriter.create(loc, source, dest, - offsets, strides); + return vector::InsertStridedSliceOp::create(rewriter, loc, source, dest, + offsets, strides); } /// Constructs a new vector of type `resultType` by creating a series of @@ -254,12 +254,12 @@ struct ConvertAddI final : OpConversionPattern { extractLastDimHalves(rewriter, loc, adaptor.getRhs()); auto lowSum = - rewriter.create(loc, lhsElem0, rhsElem0); + arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0); Value overflowVal = - rewriter.create(loc, newElemTy, lowSum.getOverflow()); + arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow()); - Value high0 = rewriter.create(loc, overflowVal, lhsElem1); - Value high = rewriter.create(loc, high0, rhsElem1); + Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1); + Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high}); @@ -293,8 +293,8 @@ struct ConvertBitwiseBinary final : OpConversionPattern { auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, adaptor.getRhs()); - Value resElem0 = rewriter.create(loc, lhsElem0, rhsElem0); - Value resElem1 = rewriter.create(loc, lhsElem1, rhsElem1); + Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); rewriter.replaceOp(op, resultVec); @@ -346,26 +346,26 @@ struct ConvertCmpI final : OpConversionPattern { extractLastDimHalves(rewriter, loc, adaptor.getRhs()); Value lowCmp = - rewriter.create(loc, lowPred, lhsElem0, rhsElem0); + 
arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0); Value highCmp = - rewriter.create(loc, highPred, lhsElem1, rhsElem1); + arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1); Value cmpResult{}; switch (highPred) { case arith::CmpIPredicate::eq: { - cmpResult = rewriter.create(loc, lowCmp, highCmp); + cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp); break; } case arith::CmpIPredicate::ne: { - cmpResult = rewriter.create(loc, lowCmp, highCmp); + cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp); break; } default: { // Handle inequality checks. - Value highEq = rewriter.create( - loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1); + Value highEq = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1); cmpResult = - rewriter.create(loc, highEq, lowCmp, highCmp); + arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp); break; } } @@ -401,14 +401,14 @@ struct ConvertMulI final : OpConversionPattern { // Multiplying two i2N integers produces (at most) an i4N result, but // because the calculation of top i2N is not necessary, we omit it. 
auto mulLowLow = - rewriter.create(loc, lhsElem0, rhsElem0); - Value mulLowHi = rewriter.create(loc, lhsElem0, rhsElem1); - Value mulHiLow = rewriter.create(loc, lhsElem1, rhsElem0); + arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1); + Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0); Value resLow = mulLowLow.getLow(); Value resHi = - rewriter.create(loc, mulLowLow.getHigh(), mulLowHi); - resHi = rewriter.create(loc, resHi, mulHiLow); + arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi); + resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow); Value resultVec = constructResultVector(rewriter, loc, newTy, {resLow, resHi}); @@ -443,10 +443,10 @@ struct ConvertExtSI final : OpConversionPattern { loc, newResultComponentTy, newOperand); Value operandZeroCst = createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0); - Value signBit = rewriter.create( - loc, arith::CmpIPredicate::slt, extended, operandZeroCst); + Value signBit = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst); Value signValue = - rewriter.create(loc, newResultComponentTy, signBit); + arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit); Value resultVec = constructResultVector(rewriter, loc, newTy, {extended, signValue}); @@ -508,7 +508,7 @@ struct ConvertMaxMin final : OpConversionPattern { // Rewrite Max*I/Min*I as compare and select over original operands. Let // the CmpI and Select emulation patterns handle the final legalization. Value cmp = - rewriter.create(loc, CmpPred, op.getLhs(), op.getRhs()); + arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs()); rewriter.replaceOpWithNewOp(op, cmp, op.getLhs(), op.getRhs()); return success(); @@ -587,7 +587,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern { // Sign or zero-extend the result. 
Let the matching conversion pattern // legalize the extension op. Value underlyingVal = - rewriter.create(loc, narrowTy, adaptor.getIn()); + CastOp::create(rewriter, loc, narrowTy, adaptor.getIn()); rewriter.replaceOpWithNewOp(op, resultType, underlyingVal); return success(); } @@ -616,9 +616,9 @@ struct ConvertSelect final : OpConversionPattern { Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition()); Value resElem0 = - rewriter.create(loc, cond, trueElem0, falseElem0); + arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0); Value resElem1 = - rewriter.create(loc, cond, trueElem1, falseElem1); + arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); rewriter.replaceOp(op, resultVec); @@ -680,33 +680,33 @@ struct ConvertShLI final : OpConversionPattern { Value elemBitWidth = createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); - Value illegalElemShift = rewriter.create( - loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); + Value illegalElemShift = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); Value shiftedElem0 = - rewriter.create(loc, lhsElem0, rhsElem0); - Value resElem0 = rewriter.create(loc, illegalElemShift, - zeroCst, shiftedElem0); + arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift, + zeroCst, shiftedElem0); - Value cappedShiftAmount = rewriter.create( - loc, illegalElemShift, elemBitWidth, rhsElem0); + Value cappedShiftAmount = arith::SelectOp::create( + rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0); Value rightShiftAmount = - rewriter.create(loc, elemBitWidth, cappedShiftAmount); + arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount); Value shiftedRight = - rewriter.create(loc, lhsElem0, rightShiftAmount); + arith::ShRUIOp::create(rewriter, loc, lhsElem0, 
rightShiftAmount); Value overshotShiftAmount = - rewriter.create(loc, rhsElem0, elemBitWidth); + arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth); Value shiftedLeft = - rewriter.create(loc, lhsElem0, overshotShiftAmount); + arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount); Value shiftedElem1 = - rewriter.create(loc, lhsElem1, rhsElem0); - Value resElem1High = rewriter.create( - loc, illegalElemShift, zeroCst, shiftedElem1); - Value resElem1Low = rewriter.create( - loc, illegalElemShift, shiftedLeft, shiftedRight); + arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0); + Value resElem1High = arith::SelectOp::create( + rewriter, loc, illegalElemShift, zeroCst, shiftedElem1); + Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift, + shiftedLeft, shiftedRight); Value resElem1 = - rewriter.create(loc, resElem1Low, resElem1High); + arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); @@ -769,33 +769,33 @@ struct ConvertShRUI final : OpConversionPattern { Value elemBitWidth = createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); - Value illegalElemShift = rewriter.create( - loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); + Value illegalElemShift = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); Value shiftedElem0 = - rewriter.create(loc, lhsElem0, rhsElem0); - Value resElem0Low = rewriter.create(loc, illegalElemShift, - zeroCst, shiftedElem0); + arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift, + zeroCst, shiftedElem0); Value shiftedElem1 = - rewriter.create(loc, lhsElem1, rhsElem0); - Value resElem1 = rewriter.create(loc, illegalElemShift, - zeroCst, shiftedElem1); + arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0); + Value resElem1 = 
arith::SelectOp::create(rewriter, loc, illegalElemShift, + zeroCst, shiftedElem1); - Value cappedShiftAmount = rewriter.create( - loc, illegalElemShift, elemBitWidth, rhsElem0); + Value cappedShiftAmount = arith::SelectOp::create( + rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0); Value leftShiftAmount = - rewriter.create(loc, elemBitWidth, cappedShiftAmount); + arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount); Value shiftedLeft = - rewriter.create(loc, lhsElem1, leftShiftAmount); + arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount); Value overshotShiftAmount = - rewriter.create(loc, rhsElem0, elemBitWidth); + arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth); Value shiftedRight = - rewriter.create(loc, lhsElem1, overshotShiftAmount); + arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount); - Value resElem0High = rewriter.create( - loc, illegalElemShift, shiftedRight, shiftedLeft); + Value resElem0High = arith::SelectOp::create( + rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft); Value resElem0 = - rewriter.create(loc, resElem0Low, resElem0High); + arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); @@ -832,33 +832,33 @@ struct ConvertShRSI final : OpConversionPattern { // Perform as many ops over the narrow integer type as possible and let the // other emulation patterns convert the rest. Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0); - Value signBit = rewriter.create( - loc, arith::CmpIPredicate::slt, lhsElem1, elemZero); + Value signBit = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero); signBit = dropTrailingX1Dim(rewriter, loc, signBit); // Create a bit pattern of either all ones or all zeros. Then shift it left // to calculate the sign extension bits created by shifting the original // sign bit right. 
- Value allSign = rewriter.create(loc, oldTy, signBit); + Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit); Value maxShift = createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth); Value numNonSignExtBits = - rewriter.create(loc, maxShift, rhsElem0); + arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0); numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits); numNonSignExtBits = - rewriter.create(loc, oldTy, numNonSignExtBits); + arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits); Value signBits = - rewriter.create(loc, allSign, numNonSignExtBits); + arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits); // Use original arguments to create the right shift. Value shrui = - rewriter.create(loc, op.getLhs(), op.getRhs()); - Value shrsi = rewriter.create(loc, shrui, signBits); + arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs()); + Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits); // Handle shifting by zero. This is necessary when the `signBits` shift is // invalid. - Value isNoop = rewriter.create(loc, arith::CmpIPredicate::eq, - rhsElem0, elemZero); + Value isNoop = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero); isNoop = dropTrailingX1Dim(rewriter, loc, isNoop); rewriter.replaceOpWithNewOp(op, isNoop, op.getLhs(), shrsi); @@ -892,14 +892,14 @@ struct ConvertSubI final : OpConversionPattern { // Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where // CARRY is 1 or 0. - Value low = rewriter.create(loc, lhsElem0, rhsElem0); + Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0); // We have a carry if lhsElem0 < rhsElem0. 
- Value carry0 = rewriter.create( - loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0); - Value carryVal = rewriter.create(loc, newElemTy, carry0); + Value carry0 = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0); + Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0); - Value high0 = rewriter.create(loc, lhsElem1, carryVal); - Value high = rewriter.create(loc, high0, rhsElem1); + Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal); + Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high}); rewriter.replaceOp(op, resultVec); @@ -933,13 +933,13 @@ struct ConvertSIToFP final : OpConversionPattern { // result or not based on that sign bit. We implement negation by // subtracting from zero. Note that this relies on the the other conversion // patterns to legalize created ops and narrow the bit widths. - Value isNeg = rewriter.create(loc, arith::CmpIPredicate::slt, - in, zeroCst); - Value neg = rewriter.create(loc, zeroCst, in); - Value abs = rewriter.create(loc, isNeg, neg, in); + Value isNeg = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::slt, in, zeroCst); + Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in); + Value abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in); - Value absResult = rewriter.create(loc, op.getType(), abs); - Value negResult = rewriter.create(loc, absResult); + Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs); + Value negResult = arith::NegFOp::create(rewriter, loc, absResult); rewriter.replaceOpWithNewOp(op, isNeg, negResult, absResult); return success(); @@ -985,13 +985,13 @@ struct ConvertUIToFP final : OpConversionPattern { // // Note 2: We do not strictly need the `hi == 0`, case, but it makes // constant folding easier. 
- Value hiEqZero = rewriter.create( - loc, arith::CmpIPredicate::eq, hiInt, zeroCst); + Value hiEqZero = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst); Type resultTy = op.getType(); Type resultElemTy = getElementTypeOrSelf(resultTy); - Value lowFp = rewriter.create(loc, resultTy, lowInt); - Value hiFp = rewriter.create(loc, resultTy, hiInt); + Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt); + Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt); int64_t pow2Int = int64_t(1) << newBitWidth; TypedAttr pow2Attr = @@ -999,10 +999,11 @@ struct ConvertUIToFP final : OpConversionPattern { if (auto vecTy = dyn_cast(resultTy)) pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr); - Value pow2Val = rewriter.create(loc, resultTy, pow2Attr); + Value pow2Val = + arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr); - Value hiVal = rewriter.create(loc, hiFp, pow2Val); - Value result = rewriter.create(loc, lowFp, hiVal); + Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val); + Value result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal); rewriter.replaceOpWithNewOp(op, hiEqZero, lowFp, result); return success(); @@ -1037,22 +1038,22 @@ struct ConvertFPToSI final : OpConversionPattern { // result is UB. TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy); - Value zeroCst = rewriter.create(loc, zeroAttr); + Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr); Value zeroCstInt = createScalarOrSplatConstant(rewriter, loc, intTy, 0); // Get the absolute value. One could have used math.absf here, but that // introduces an extra dependency. 
- Value isNeg = rewriter.create(loc, arith::CmpFPredicate::OLT, - inFp, zeroCst); - Value negInFp = rewriter.create(loc, inFp); + Value isNeg = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst); + Value negInFp = arith::NegFOp::create(rewriter, loc, inFp); - Value absVal = rewriter.create(loc, isNeg, negInFp, inFp); + Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp); // Defer the absolute value to fptoui. - Value res = rewriter.create(loc, intTy, absVal); + Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal); // Negate the value if < 0 . - Value neg = rewriter.create(loc, zeroCstInt, res); + Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res); rewriter.replaceOpWithNewOp(op, isNeg, neg, res); return success(); @@ -1109,17 +1110,17 @@ struct ConvertFPToUI final : OpConversionPattern { if (auto vecType = dyn_cast(fpTy)) powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr); Value powBitwidthFloatCst = - rewriter.create(loc, powBitwidthAttr); + arith::ConstantOp::create(rewriter, loc, powBitwidthAttr); Value fpDivPowBitwidth = - rewriter.create(loc, inFp, powBitwidthFloatCst); + arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst); Value resHigh = - rewriter.create(loc, newHalfType, fpDivPowBitwidth); + arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth); // Calculate fp - resHigh * 2^N by getting the remainder of the division Value remainder = - rewriter.create(loc, inFp, powBitwidthFloatCst); + arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst); Value resLow = - rewriter.create(loc, newHalfType, remainder); + arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder); Value high = appendX1Dim(rewriter, loc, resHigh); Value low = appendX1Dim(rewriter, loc, resLow); diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index e842f44b3b97f..f8fa35c6fa7de 100644 --- 
a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -28,10 +28,10 @@ static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter) { auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast(type)) { - return rewriter.create( - loc, DenseElementsAttr::get(shapedTy, attr)); + return arith::ConstantOp::create(rewriter, loc, + DenseElementsAttr::get(shapedTy, attr)); } - return rewriter.create(loc, attr); + return arith::ConstantOp::create(rewriter, loc, attr); } /// Create a float constant. @@ -39,11 +39,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value, PatternRewriter &rewriter) { auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast(type)) { - return rewriter.create( - loc, DenseElementsAttr::get(shapedTy, attr)); + return arith::ConstantOp::create(rewriter, loc, + DenseElementsAttr::get(shapedTy, attr)); } - return rewriter.create(loc, attr); + return arith::ConstantOp::create(rewriter, loc, attr); } /// Creates shapedType using shape from cloneFrom and base type from cloneTo @@ -67,11 +67,11 @@ struct CeilDivUIOpConverter : public OpRewritePattern { Value b = op.getRhs(); Value zero = createConst(loc, a.getType(), 0, rewriter); Value compare = - rewriter.create(loc, arith::CmpIPredicate::eq, a, zero); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero); Value one = createConst(loc, a.getType(), 1, rewriter); - Value minusOne = rewriter.create(loc, a, one); - Value quotient = rewriter.create(loc, minusOne, b); - Value plusOne = rewriter.create(loc, quotient, one); + Value minusOne = arith::SubIOp::create(rewriter, loc, a, one); + Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b); + Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one); rewriter.replaceOpWithNewOp(op, compare, zero, plusOne); return success(); } @@ 
-96,22 +96,22 @@ struct CeilDivSIOpConverter : public OpRewritePattern { Value zero = createConst(loc, type, 0, rewriter); Value one = createConst(loc, type, 1, rewriter); - Value quotient = rewriter.create(loc, a, b); - Value product = rewriter.create(loc, quotient, b); - Value notEqualDivisor = rewriter.create( - loc, arith::CmpIPredicate::ne, a, product); + Value quotient = arith::DivSIOp::create(rewriter, loc, a, b); + Value product = arith::MulIOp::create(rewriter, loc, quotient, b); + Value notEqualDivisor = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ne, a, product); - Value aNeg = - rewriter.create(loc, arith::CmpIPredicate::slt, a, zero); - Value bNeg = - rewriter.create(loc, arith::CmpIPredicate::slt, b, zero); + Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + a, zero); + Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + b, zero); - Value signEqual = rewriter.create( - loc, arith::CmpIPredicate::eq, aNeg, bNeg); + Value signEqual = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg); Value cond = - rewriter.create(loc, notEqualDivisor, signEqual); + arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual); - Value quotientPlusOne = rewriter.create(loc, quotient, one); + Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one); rewriter.replaceOpWithNewOp(op, cond, quotientPlusOne, quotient); @@ -135,25 +135,25 @@ struct FloorDivSIOpConverter : public OpRewritePattern { Value a = op.getLhs(); Value b = op.getRhs(); - Value quotient = rewriter.create(loc, a, b); - Value product = rewriter.create(loc, quotient, b); - Value notEqualDivisor = rewriter.create( - loc, arith::CmpIPredicate::ne, a, product); + Value quotient = arith::DivSIOp::create(rewriter, loc, a, b); + Value product = arith::MulIOp::create(rewriter, loc, quotient, b); + Value notEqualDivisor = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ne, a, 
product); Value zero = createConst(loc, type, 0, rewriter); - Value aNeg = - rewriter.create(loc, arith::CmpIPredicate::slt, a, zero); - Value bNeg = - rewriter.create(loc, arith::CmpIPredicate::slt, b, zero); + Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + a, zero); + Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + b, zero); - Value signOpposite = rewriter.create( - loc, arith::CmpIPredicate::ne, aNeg, bNeg); + Value signOpposite = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg); Value cond = - rewriter.create(loc, notEqualDivisor, signOpposite); + arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite); Value minusOne = createConst(loc, type, -1, rewriter); Value quotientMinusOne = - rewriter.create(loc, quotient, minusOne); + arith::AddIOp::create(rewriter, loc, quotient, minusOne); rewriter.replaceOpWithNewOp(op, cond, quotientMinusOne, quotient); @@ -171,7 +171,7 @@ struct MaxMinIOpConverter : public OpRewritePattern { Value lhs = op.getLhs(); Value rhs = op.getRhs(); - Value cmp = rewriter.create(op.getLoc(), pred, lhs, rhs); + Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs); rewriter.replaceOpWithNewOp(op, cmp, lhs, rhs); return success(); } @@ -192,12 +192,12 @@ struct MaximumMinimumFOpConverter : public OpRewritePattern { static_assert(pred == arith::CmpFPredicate::UGT || pred == arith::CmpFPredicate::ULT, "pred must be either UGT or ULT"); - Value cmp = rewriter.create(loc, pred, lhs, rhs); - Value select = rewriter.create(loc, cmp, lhs, rhs); + Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs); + Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs); // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. 
- Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, - rhs, rhs); + Value isNaN = arith::CmpFOp::create(rewriter, loc, + arith::CmpFPredicate::UNO, rhs, rhs); rewriter.replaceOpWithNewOp(op, isNaN, rhs, select); return success(); } @@ -218,12 +218,12 @@ struct MaxNumMinNumFOpConverter : public OpRewritePattern { static_assert(pred == arith::CmpFPredicate::UGT || pred == arith::CmpFPredicate::ULT, "pred must be either UGT or ULT"); - Value cmp = rewriter.create(loc, pred, lhs, rhs); - Value select = rewriter.create(loc, cmp, lhs, rhs); + Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs); + Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs); // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'. - Value isNaN = rewriter.create(loc, arith::CmpFPredicate::UNO, - lhs, lhs); + Value isNaN = arith::CmpFOp::create(rewriter, loc, + arith::CmpFPredicate::UNO, lhs, lhs); rewriter.replaceOpWithNewOp(op, isNaN, rhs, select); return success(); } @@ -247,12 +247,12 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern { Type i16Ty = cloneToShapedType(operandTy, b.getI16Type()); Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); - Value bitcast = b.create(i16Ty, operand); - Value exti = b.create(i32Ty, bitcast); + Value bitcast = arith::BitcastOp::create(b, i16Ty, operand); + Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast); Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); - Value shl = b.create(exti, c16); - Value result = b.create(resultTy, shl); + Value shl = arith::ShLIOp::create(b, exti, c16); + Value result = arith::BitcastOp::create(b, resultTy, shl); rewriter.replaceOp(op, result); return success(); @@ -296,7 +296,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern { // exponent bits, that simple truncation is the desired outcome for // infinities. 
Value isNan = - b.create(arith::CmpFPredicate::UNE, operand, operand); + arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand); // Constant used to make the rounding bias. Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter); // Constant used to generate a quiet NaN. @@ -305,30 +305,30 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern { Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter); // Reinterpret the input f32 value as bits. - Value bitcast = b.create(i32Ty, operand); + Value bitcast = arith::BitcastOp::create(b, i32Ty, operand); // Read bit 16 as a value in {0,1}. Value bit16 = - b.create(b.create(bitcast, c16), c1); + arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1); // Determine the rounding bias to add as either 0x7fff or 0x8000 depending // on bit 16, implementing the tie-breaking "to nearest even". - Value roundingBias = b.create(bit16, c7FFF); + Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF); // Add the rounding bias. Generally we want this to be added to the // mantissa, but nothing prevents this to from carrying into the exponent // bits, which would feel like a bug, but this is the magic trick here: // when that happens, the mantissa gets reset to zero and the exponent // gets incremented by the carry... which is actually exactly what we // want. - Value biased = b.create(bitcast, roundingBias); + Value biased = arith::AddIOp::create(b, bitcast, roundingBias); // Now that the rounding-bias has been added, truncating the low bits // yields the correctly rounded result. - Value biasedAndShifted = b.create(biased, c16); + Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16); Value normalCaseResultI16 = - b.create(i16Ty, biasedAndShifted); + arith::TruncIOp::create(b, i16Ty, biasedAndShifted); // Select either the above-computed result, or a quiet NaN constant // if the input was NaN. 
Value select = - b.create(isNan, c7FC0I16, normalCaseResultI16); - Value result = b.create(resultTy, select); + arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16); + Value result = arith::BitcastOp::create(b, resultTy, select); rewriter.replaceOp(op, result); return success(); } @@ -381,7 +381,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); Type i4Ty = cloneToShapedType(operandTy, b.getI4Type()); Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); - Value i4Bits = b.create(i4Ty, operand); + Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand); Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter); Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter); @@ -390,38 +390,39 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern { // Set last Exponent bit and Mantissa. Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter); - Value bits1To24 = b.create(i4Bits, c0x2); + Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2); Value isHalf = - b.create(arith::CmpIPredicate::eq, i4Bits, c0x1); - bits1To24 = b.create(isHalf, c0x0, bits1To24); - bits1To24 = b.create(i32Ty, bits1To24); - bits1To24 = b.create(bits1To24, c0x00000014); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1); + bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24); + bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24); + bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014); // Set first 7 bits of Exponent. 
Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter); Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter); Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter); Value useLargerExp = - b.create(arith::CmpIPredicate::uge, i4Bits, c0x4); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4); Value bits25To31 = - b.create(useLargerExp, highExpBits, lowExpBits); + arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits); Value zeroExp = - b.create(arith::CmpIPredicate::eq, i4Bits, c0x0); - bits25To31 = b.create(zeroExp, zeroExpBits, bits25To31); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0); + bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31); // Set sign. Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter); Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter); Value negative = - b.create(arith::CmpIPredicate::uge, i4Bits, c0x8); - Value bit32 = b.create(negative, c0x80000000, zeroExpBits); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8); + Value bit32 = + arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits); // Add segments together. 
- Value bits1To31 = b.create(bits1To24, bits25To31); - Value bits1To32 = b.create(bits1To31, bit32); - Value result = b.create(f32Ty, bits1To32); + Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31); + Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32); + Value result = arith::BitcastOp::create(b, f32Ty, bits1To32); if (!isa(resultETy)) - result = b.create(resultTy, result); + result = arith::TruncFOp::create(b, resultTy, result); rewriter.replaceOp(op, result); return success(); @@ -447,25 +448,25 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern { Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); - Value bitcast = b.create(i8Ty, operand); + Value bitcast = arith::BitcastOp::create(b, i8Ty, operand); // create constants for NaNs Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter); Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter); Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); - Value exti = b.create(i32Ty, bitcast); - Value f32Bits = b.create(exti, cF32MantissaWidth); + Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast); + Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth); Value isNan = - b.create(arith::CmpIPredicate::eq, bitcast, cF8NaN); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN); // select for NaNs - f32Bits = b.create(isNan, cF32NaN, f32Bits); - Value result = b.create(f32Ty, f32Bits); + f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits); + Value result = arith::BitcastOp::create(b, f32Ty, f32Bits); if (resultETy.getIntOrFloatBitWidth() < 32) { - result = b.create(resultTy, result, nullptr, - op.getFastmathAttr()); + result = arith::TruncFOp::create(b, resultTy, result, nullptr, + op.getFastmathAttr()); } else if (resultETy.getIntOrFloatBitWidth() > 32) { - result = b.create(resultTy, result, op.getFastmathAttr()); + result = 
arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr()); } rewriter.replaceOp(op, result); return success(); @@ -520,7 +521,7 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern { if (!isa(resultETy)) return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN"); if (!isa(operandETy)) - operand = b.create(f32Ty, operand); + operand = arith::ExtFOp::create(b, f32Ty, operand); Value c0x1 = createConst(loc, i4Ty, 1, rewriter); Value c0x3 = createConst(loc, i4Ty, 3, rewriter); @@ -532,65 +533,65 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern { // Step 0: Clamp to bounds. Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter); Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter); - Value operandClamped = b.create(cHigherBound, operand); - operandClamped = b.create(cLowerBound, operandClamped); - Value f32Bits = b.create(i32Ty, operandClamped); + Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand); + operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped); + Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped); // Step 1: Set sign bit. Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 23 - Value f32Sign = b.create(f32Bits, cF32ExpManWidth); - Value f4Sign = b.create(i4Ty, f32Sign); - Value f4Bits = b.create(f4Sign, c0x3); + Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth); + Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign); + Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3); // Step 2: Convert exponent by adjusting bias. 
Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter); Value cF4MantissaWidth = c0x1; // 1 Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23 - Value f32SignExp = b.create(f32Bits, cF32MantissaWidth); + Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth); Value biasAdjustedSignExp = - b.create(f32SignExp, biasAdjustment); - Value f4Exp = b.create(i4Ty, biasAdjustedSignExp); - f4Exp = b.create(f4Exp, cF4MantissaWidth); - f4Bits = b.create(f4Bits, f4Exp); + arith::SubIOp::create(b, f32SignExp, biasAdjustment); + Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp); + f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth); + f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp); // Step 3: Set mantissa to first bit. Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter); - Value man1Bit = b.create(f32Bits, cF32FirstBitMask); - man1Bit = b.create(man1Bit, c0x00000016); - Value f4Man = b.create(i4Ty, man1Bit); - f4Bits = b.create(f4Bits, f4Man); + Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask); + man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016); + Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit); + f4Bits = arith::AddIOp::create(b, f4Bits, f4Man); // Step 4: Special consideration for conversion to 0.5. 
Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter); - Value f8Exp = b.create(i8Ty, biasAdjustedSignExp); + Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp); Value isSubnormal = - b.create(arith::CmpIPredicate::sle, f8Exp, c0x00); + arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00); Value isNegOneExp = - b.create(arith::CmpIPredicate::eq, f8Exp, c0xff); - Value man23Bits = b.create(f32Bits, cF32MantissaMask); - Value isNonZeroMan = b.create(arith::CmpIPredicate::ugt, - man23Bits, zeroExpBits); - Value roundToHalf = b.create(isNegOneExp, isNonZeroMan); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff); + Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask); + Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt, + man23Bits, zeroExpBits); + Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan); Value isZeroExp = - b.create(arith::CmpIPredicate::eq, f8Exp, c0x00); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00); Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter); Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter); Value subResult = - b.create(isSubnormal, subnormalF4Bits, f4Bits); - subResult = b.create(roundToHalf, halfF4Bits, subResult); - f4Bits = b.create(isZeroExp, f4Bits, subResult); + arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits); + subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult); + f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult); // Step 5: Round up if necessary. Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter); Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000... 
- Value man22Bits = b.create(f32Bits, cF32Last22BitMask); + Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask); Value shouldRound = - b.create(arith::CmpIPredicate::uge, man22Bits, cRound); - shouldRound = b.create(shouldRound, isSubnormal); - Value roundedF4Bits = b.create(f4Bits, c0x1); - f4Bits = b.create(shouldRound, roundedF4Bits, f4Bits); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound); + shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal); + Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1); + f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits); - Value result = b.create(resultTy, f4Bits); + Value result = arith::BitcastOp::create(b, resultTy, f4Bits); rewriter.replaceOp(op, result); return success(); } @@ -625,16 +626,16 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); if (operandETy.getIntOrFloatBitWidth() < 32) { - operand = b.create(f32Ty, operand, op.getFastmathAttr()); + operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr()); } else if (operandETy.getIntOrFloatBitWidth() > 32) { - operand = b.create( - f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); + operand = arith::TruncFOp::create( + b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); } - Value f32Bits = b.create(i32Ty, operand); + Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand); Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); - Value f32SignExp = b.create(f32Bits, cF32MantissaWidth); - Value exp8Bits = b.create(i8Ty, f32SignExp); - Value result = b.create(resultTy, exp8Bits); + Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth); + Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp); + Value result = arith::BitcastOp::create(b, resultTy, exp8Bits); rewriter.replaceOp(op, result); return success(); } @@ -653,8 +654,8 @@ 
struct ScalingExtFOpConverter : public OpRewritePattern { if (scaleETy.getIntOrFloatBitWidth() >= 16) { scaleETy = b.getF8E8M0Type(); scaleTy = cloneToShapedType(scaleTy, scaleETy); - scaleOperand = b.create(scaleTy, scaleOperand, nullptr, - op.getFastmathAttr()); + scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr, + op.getFastmathAttr()); } // Catch scale types like f8E5M2. if (!llvm::isa(scaleETy)) { @@ -666,11 +667,11 @@ struct ScalingExtFOpConverter : public OpRewritePattern { // extf on scale will essentially create floating point number // of type resulTy that is 2^scale and will also propagate NaNs Value scaleExt = - b.create(resultTy, scaleOperand, op.getFastmathAttr()); + arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr()); Value inputExt = - b.create(resultTy, inputOperand, op.getFastmathAttr()); + arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr()); Value result = - b.create(inputExt, scaleExt, op.getFastmathAttr()); + arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr()); rewriter.replaceOp(op, result); return success(); } @@ -695,8 +696,8 @@ struct ScalingTruncFOpConverter if (scaleETy.getIntOrFloatBitWidth() >= 16) { scaleETy = b.getF8E8M0Type(); scaleTy = cloneToShapedType(scaleTy, scaleETy); - scaleOperand = b.create(scaleTy, scaleOperand, nullptr, - op.getFastmathAttr()); + scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr, + op.getFastmathAttr()); } if (!llvm::isa(scaleETy)) { return rewriter.notifyMatchFailure( @@ -708,11 +709,11 @@ struct ScalingTruncFOpConverter // this will create a floating point number of type // inputTy that is 2^scale and will also propagate NaNs scaleOperand = - b.create(inputTy, scaleOperand, op.getFastmathAttr()); - Value result = b.create(inputOperand, scaleOperand, - op.getFastmathAttr()); - Value resultCast = b.create( - resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); + arith::ExtFOp::create(b, 
inputTy, scaleOperand, op.getFastmathAttr()); + Value result = arith::DivFOp::create(b, inputOperand, scaleOperand, + op.getFastmathAttr()); + Value resultCast = arith::TruncFOp::create( + b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); rewriter.replaceOp(op, resultCast); return success(); } diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index f2f93883eb2b7..777ff0ecaa314 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -305,18 +305,18 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType, if (isa(srcElemType) || isa(dstElemType)) { if (castKind == CastKind::Signed) - return builder.create(loc, dstType, src); - return builder.create(loc, dstType, src); + return arith::IndexCastOp::create(builder, loc, dstType, src); + return arith::IndexCastUIOp::create(builder, loc, dstType, src); } auto srcInt = cast(srcElemType); auto dstInt = cast(dstElemType); if (dstInt.getWidth() < srcInt.getWidth()) - return builder.create(loc, dstType, src); + return arith::TruncIOp::create(builder, loc, dstType, src); if (castKind == CastKind::Signed) - return builder.create(loc, dstType, src); - return builder.create(loc, dstType, src); + return arith::ExtSIOp::create(builder, loc, dstType, src); + return arith::ExtUIOp::create(builder, loc, dstType, src); } struct NarrowElementwise final : OpTraitRewritePattern { diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index 5fb7953f93700..4bdd1e6a54d69 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -23,8 +23,8 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, std::function buildExpr = [&](AffineExpr e) -> Value { switch (e.getKind()) { 
case AffineExprKind::Constant: - return b.create(loc, - cast(e).getValue()); + return ConstantIndexOp::create(b, loc, + cast(e).getValue()); case AffineExprKind::DimId: return operands[cast(e).getPosition()]; case AffineExprKind::SymbolId: @@ -32,28 +32,28 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, map.getNumDims()]; case AffineExprKind::Add: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return AddIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } case AffineExprKind::Mul: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return MulIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } case AffineExprKind::FloorDiv: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return DivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } case AffineExprKind::CeilDiv: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return CeilDivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } case AffineExprKind::Mod: { auto binaryExpr = cast(e); - return b.create(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return RemSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } } llvm_unreachable("unsupported AffineExpr kind"); @@ -89,10 +89,10 @@ FailureOr mlir::arith::reifyValueBound( "expected dynamic dim"); if (isa(value.getType())) { // A tensor dimension is used: generate a tensor.dim. 
- operands.push_back(b.create(loc, value, *dim)); + operands.push_back(tensor::DimOp::create(b, loc, value, *dim)); } else if (isa(value.getType())) { // A memref dimension is used: generate a memref.dim. - operands.push_back(b.create(loc, value, *dim)); + operands.push_back(memref::DimOp::create(b, loc, value, *dim)); } else { llvm_unreachable("cannot generate DimOp for unsupported shaped type"); } diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp index 3478adcb4a128..dd6efe6d6bc31 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -83,7 +83,7 @@ struct ConstantShardingInterface cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable), sharding)); auto newValue = value.resizeSplat(newType); - auto newOp = builder.create(op->getLoc(), newType, newValue); + auto newOp = ConstantOp::create(builder, op->getLoc(), newType, newValue); spmdizationMap.map(op->getResult(0), newOp.getResult()); spmdizationMap.map(op, newOp.getOperation()); } else { diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index bdeeccbe0177a..b1fc9aa57c3ba 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -67,7 +67,7 @@ mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc, // dynamism. 
Value indexGroupSize = cast(inputShape[inputIndex]); Value indexGroupStaticSizesProduct = - b.create(loc, indexGroupStaticSizesProductInt); + arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt); Value dynamicDimSize = b.createOrFold( loc, indexGroupSize, indexGroupStaticSizesProduct); outputShapeValues.push_back(dynamicDimSize); @@ -104,8 +104,8 @@ Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, if (auto value = dyn_cast_if_present(ofr)) return value; auto attr = cast(cast(ofr)); - return b.create( - loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue())); + return arith::ConstantOp::create( + b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue())); } Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, @@ -113,7 +113,7 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, if (auto value = dyn_cast_if_present(ofr)) return value; auto attr = cast(cast(ofr)); - return b.create(loc, attr.getValue().getSExtValue()); + return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue()); } Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, @@ -124,7 +124,7 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, bool targetIsIndex = targetType.isIndex(); bool valueIsIndex = value.getType().isIndex(); if (targetIsIndex ^ valueIsIndex) - return b.create(loc, targetType, value); + return arith::IndexCastOp::create(b, loc, targetType, value); auto targetIntegerType = dyn_cast(targetType); auto valueIntegerType = dyn_cast(value.getType()); @@ -133,8 +133,8 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) - return b.create(loc, targetIntegerType, value); - return b.create(loc, targetIntegerType, value); + return arith::ExtSIOp::create(b, loc, 
targetIntegerType, value); + return arith::TruncIOp::create(b, loc, targetIntegerType, value); } static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, @@ -142,21 +142,21 @@ static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, // If operand is floating point, cast directly to the int type. if (isa(operand.getType())) { if (isUnsigned) - return b.create(toType, operand); - return b.create(toType, operand); + return arith::FPToUIOp::create(b, toType, operand); + return arith::FPToSIOp::create(b, toType, operand); } // Cast index operands directly to the int type. if (operand.getType().isIndex()) - return b.create(toType, operand); + return arith::IndexCastOp::create(b, toType, operand); if (auto fromIntType = dyn_cast(operand.getType())) { // Either extend or truncate. if (toType.getWidth() > fromIntType.getWidth()) { if (isUnsigned) - return b.create(toType, operand); - return b.create(toType, operand); + return arith::ExtUIOp::create(b, toType, operand); + return arith::ExtSIOp::create(b, toType, operand); } if (toType.getWidth() < fromIntType.getWidth()) - return b.create(toType, operand); + return arith::TruncIOp::create(b, toType, operand); return operand; } @@ -169,14 +169,14 @@ static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, // Note that it is unclear how to cast from BF16<->FP16. 
if (isa(operand.getType())) { if (isUnsigned) - return b.create(toType, operand); - return b.create(toType, operand); + return arith::UIToFPOp::create(b, toType, operand); + return arith::SIToFPOp::create(b, toType, operand); } if (auto fromFpTy = dyn_cast(operand.getType())) { if (toType.getWidth() > fromFpTy.getWidth()) - return b.create(toType, operand); + return arith::ExtFOp::create(b, toType, operand); if (toType.getWidth() < fromFpTy.getWidth()) - return b.create(toType, operand); + return arith::TruncFOp::create(b, toType, operand); return operand; } @@ -189,18 +189,18 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, if (auto fromComplexType = dyn_cast(operand.getType())) { if (isa(targetType.getElementType()) && isa(fromComplexType.getElementType())) { - Value real = b.create(operand); - Value imag = b.create(operand); + Value real = complex::ReOp::create(b, operand); + Value imag = complex::ImOp::create(b, operand); Type targetETy = targetType.getElementType(); if (targetType.getElementType().getIntOrFloatBitWidth() < fromComplexType.getElementType().getIntOrFloatBitWidth()) { - real = b.create(targetETy, real); - imag = b.create(targetETy, imag); + real = arith::TruncFOp::create(b, targetETy, real); + imag = arith::TruncFOp::create(b, targetETy, imag); } else { - real = b.create(targetETy, real); - imag = b.create(targetETy, imag); + real = arith::ExtFOp::create(b, targetETy, real); + imag = arith::ExtFOp::create(b, targetETy, imag); } - return b.create(targetType, real, imag); + return complex::CreateOp::create(b, targetType, real, imag); } } @@ -209,27 +209,27 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, auto toBitwidth = toFpTy.getIntOrFloatBitWidth(); Value from = operand; if (from.getType().getIntOrFloatBitWidth() < toBitwidth) { - from = b.create(toFpTy, from); + from = arith::ExtFOp::create(b, toFpTy, from); } if (from.getType().getIntOrFloatBitWidth() > toBitwidth) { - from 
= b.create(toFpTy, from); + from = arith::TruncFOp::create(b, toFpTy, from); } - Value zero = b.create( - toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); - return b.create(targetType, from, zero); + Value zero = mlir::arith::ConstantFloatOp::create( + b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); + return complex::CreateOp::create(b, targetType, from, zero); } if (isa(operand.getType())) { FloatType toFpTy = cast(targetType.getElementType()); Value from = operand; if (isUnsigned) { - from = b.create(toFpTy, from); + from = arith::UIToFPOp::create(b, toFpTy, from); } else { - from = b.create(toFpTy, from); + from = arith::SIToFPOp::create(b, toFpTy, from); } - Value zero = b.create( - toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); - return b.create(targetType, from, zero); + Value zero = mlir::arith::ConstantFloatOp::create( + b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); + return complex::CreateOp::create(b, targetType, from, zero); } return {}; @@ -277,7 +277,7 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, attr = SplatElementsAttr::get(vecTy, value); } - return builder.create(loc, attr); + return arith::ConstantOp::create(builder, loc, attr); } Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, @@ -309,35 +309,35 @@ Type mlir::getType(OpFoldResult ofr) { } Value ArithBuilder::_and(Value lhs, Value rhs) { - return b.create(loc, lhs, rhs); + return arith::AndIOp::create(b, loc, lhs, rhs); } Value ArithBuilder::add(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, lhs, rhs); - return b.create(loc, lhs, rhs, ovf); + return arith::AddFOp::create(b, loc, lhs, rhs); + return arith::AddIOp::create(b, loc, lhs, rhs, ovf); } Value ArithBuilder::sub(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, lhs, rhs); - return b.create(loc, lhs, rhs, ovf); + return arith::SubFOp::create(b, loc, lhs, rhs); + return arith::SubIOp::create(b, loc, lhs, 
rhs, ovf); } Value ArithBuilder::mul(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, lhs, rhs); - return b.create(loc, lhs, rhs, ovf); + return arith::MulFOp::create(b, loc, lhs, rhs); + return arith::MulIOp::create(b, loc, lhs, rhs, ovf); } Value ArithBuilder::sgt(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, arith::CmpFPredicate::OGT, lhs, rhs); - return b.create(loc, arith::CmpIPredicate::sgt, lhs, rhs); + return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs); + return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs); } Value ArithBuilder::slt(Value lhs, Value rhs) { if (isa(lhs.getType())) - return b.create(loc, arith::CmpFPredicate::OLT, lhs, rhs); - return b.create(loc, arith::CmpIPredicate::slt, lhs, rhs); + return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs); + return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs); } Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { - return b.create(loc, cmp, lhs, rhs); + return arith::SelectOp::create(b, loc, cmp, lhs, rhs); } namespace mlir::arith { @@ -348,8 +348,8 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef values) { Value createProduct(OpBuilder &builder, Location loc, ArrayRef values, Type resultType) { - Value one = builder.create(loc, resultType, - builder.getOneAttr(resultType)); + Value one = ConstantOp::create(builder, loc, resultType, + builder.getOneAttr(resultType)); ArithBuilder arithBuilder(builder, loc); return std::accumulate( values.begin(), values.end(), one,