From 24edc8dd17e99d4e6fdffab87261db80e1ef8c77 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 21 Jul 2025 15:26:00 -0400 Subject: [PATCH] [mlir][NFC] update `Conversion` create APIs (4/n) (#149687) See https://github.com/llvm/llvm-project/pull/147168 for more info. --- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 278 +++++++++--------- .../AffineToStandard/AffineToStandard.cpp | 44 +-- .../ArithToAMDGPU/ArithToAMDGPU.cpp | 213 +++++++------- .../ArithToArmSME/ArithToArmSME.cpp | 8 +- .../Conversion/ArithToEmitC/ArithToEmitC.cpp | 90 +++--- .../Conversion/ArithToLLVM/ArithToLLVM.cpp | 38 +-- .../Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 103 +++---- .../ArmNeon2dToIntr/ArmNeon2dToIntr.cpp | 8 +- .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 213 +++++++------- .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 112 +++---- .../Conversion/AsyncToLLVM/AsyncToLLVM.cpp | 106 +++---- .../BufferizationToMemRef.cpp | 27 +- 12 files changed, 631 insertions(+), 609 deletions(-) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index ef35ee208f002..fe3dc91328879 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter, if (i32 == valTy) return val; return valTy.getWidth() > 32 - ? Value(rewriter.create(loc, i32, val)) - : Value(rewriter.create(loc, i32, val)); + ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val)) + : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val)); } static Value createI32Constant(ConversionPatternRewriter &rewriter, Location loc, int32_t value) { Type i32 = rewriter.getI32Type(); - return rewriter.create(loc, i32, value); + return LLVM::ConstantOp::create(rewriter, loc, i32, value); } static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc, bool value) { Type llvmI1 = rewriter.getI1Type(); - return rewriter.create(loc, llvmI1, value); + return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value); } /// Returns the linear index used to access an element in the memref. @@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter, ShapedType::isDynamic(stride) ? convertUnsignedToI32(rewriter, loc, memRefDescriptor.stride(rewriter, loc, i)) - : rewriter.create(loc, i32, stride); - increment = rewriter.create(loc, increment, strideValue); + : LLVM::ConstantOp::create(rewriter, loc, i32, stride); + increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue); } - index = - index ? rewriter.create(loc, index, increment) : increment; + index = index ? LLVM::AddOp::create(rewriter, loc, index, increment) + : increment; } return index ? index : createI32Constant(rewriter, loc, 0); } @@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc, for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) { Value size = memrefDescriptor.size(rewriter, loc, i); Value stride = memrefDescriptor.stride(rewriter, loc, i); - Value maxThisDim = rewriter.create(loc, size, stride); + Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride); maxIndex = maxIndex - ? rewriter.create(loc, maxIndex, maxThisDim) + ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim) : maxThisDim; } Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex); Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth); - return rewriter.create(loc, maxIndexI32, byteWidthConst); + return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst); } static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, @@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc, Value stride; if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) { Value cacheStrideZext = - rewriter.create(loc, i16, cacheSwizzleStride); - Value swizzleBit = rewriter.create( - loc, i16, rewriter.getI16IntegerAttr(1 << 14)); - stride = rewriter.create(loc, cacheStrideZext, swizzleBit, - /*isDisjoint=*/true); + LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride); + Value swizzleBit = LLVM::ConstantOp::create( + rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14)); + stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit, + /*isDisjoint=*/true); } else { - stride = rewriter.create(loc, i16, - rewriter.getI16IntegerAttr(0)); + stride = LLVM::ConstantOp::create(rewriter, loc, i16, + rewriter.getI16IntegerAttr(0)); } // Get the number of elements. // Flag word: @@ -209,20 +209,21 @@ struct FatRawBufferCastLowering : descriptor.alignedPtr(rewriter, loc); Value offset = adaptor.getResetOffset() - ? rewriter.create( - loc, getIndexType(), rewriter.getIndexAttr(0)) + ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(), + rewriter.getIndexAttr(0)) : descriptor.offset(rewriter, loc); bool hasSizes = memrefType.getRank() > 0; // No need to unpack() and pack() all the individual sizes and strides, // so we'll just extract the arrays. - Value sizes = hasSizes ? rewriter.create( - loc, descriptor, kSizePosInMemRefDescriptor) - : Value{}; - Value strides = hasSizes - ? rewriter.create( - loc, descriptor, kStridePosInMemRefDescriptor) - : Value{}; + Value sizes = hasSizes + ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, + kSizePosInMemRefDescriptor) + : Value{}; + Value strides = + hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor, + kStridePosInMemRefDescriptor) + : Value{}; Value fatPtr = makeBufferRsrc( rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(), @@ -231,17 +232,17 @@ struct FatRawBufferCastLowering Value result = MemRefDescriptor::poison( rewriter, loc, getTypeConverter()->convertType(op.getResult().getType())); - result = rewriter.create( - loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor); - result = rewriter.create( - loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); - result = rewriter.create(loc, result, offset, - kOffsetPosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, + kAllocatedPtrPosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, + kAlignedPtrPosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, + kOffsetPosInMemRefDescriptor); if (hasSizes) { - result = rewriter.create(loc, result, sizes, - kSizePosInMemRefDescriptor); - result = rewriter.create( - loc, result, strides, kStridePosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes, + kSizePosInMemRefDescriptor); + result = LLVM::InsertValueOp::create(rewriter, loc, result, strides, + kStridePosInMemRefDescriptor); } rewriter.replaceOp(op, result); return success(); @@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { SmallVector args; if (storeData) { if (llvmBufferValType != llvmWantedDataType) { - Value castForStore = - rewriter.create(loc, llvmBufferValType, storeData); + Value castForStore = LLVM::BitcastOp::create( + rewriter, loc, llvmBufferValType, storeData); args.push_back(castForStore); } else { args.push_back(storeData); @@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { if (atomicCmpData) { if (llvmBufferValType != llvmWantedDataType) { - Value castForCmp = rewriter.create( - loc, llvmBufferValType, atomicCmpData); + Value castForCmp = LLVM::BitcastOp::create( + rewriter, loc, llvmBufferValType, atomicCmpData); args.push_back(castForCmp); } else { args.push_back(atomicCmpData); @@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { if (std::optional indexOffset = adaptor.getIndexOffset(); indexOffset && *indexOffset > 0) { Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset); - voffset = - voffset ? rewriter.create(loc, voffset, extraOffsetConst) - : extraOffsetConst; + voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset, + extraOffsetConst) + : extraOffsetConst; } - voffset = rewriter.create(loc, voffset, byteWidthConst); + voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst); args.push_back(voffset); // SGPR offset. Value sgprOffset = adaptor.getSgprOffset(); if (!sgprOffset) sgprOffset = createI32Constant(rewriter, loc, 0); - sgprOffset = rewriter.create(loc, sgprOffset, byteWidthConst); + sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst); args.push_back(sgprOffset); // bit 0: GLC = 0 (atomics drop value, less coherency) @@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern { llvm::SmallVector resultTypes(gpuOp->getNumResults(), llvmBufferValType); - Operation *lowered = rewriter.create(loc, resultTypes, args, - ArrayRef()); + Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args, + ArrayRef()); if (lowered->getNumResults() == 1) { Value replacement = lowered->getResult(0); if (llvmBufferValType != llvmWantedDataType) { - replacement = rewriter.create(loc, llvmWantedDataType, - replacement); + replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType, + replacement); } rewriter.replaceOp(gpuOp, replacement); } else { @@ -465,12 +466,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern { << chipset.majorVersion; Location loc = op->getLoc(); - rewriter.create(loc, ldsOnlyBits); + ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits); rewriter.replaceOpWithNewOp(op); } else { Location loc = op->getLoc(); - rewriter.create(loc, 0); - rewriter.create(loc, -1); + ROCDL::WaitDscntOp::create(rewriter, loc, 0); + ROCDL::BarrierSignalOp::create(rewriter, loc, -1); rewriter.replaceOpWithNewOp(op, -1); } @@ -516,19 +517,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, Type inputType = input.getType(); if (auto vectorType = dyn_cast(inputType)) { if (vectorType.getElementType().isBF16() && !allowBf16) - return rewriter.create( - loc, vectorType.clone(rewriter.getI16Type()), input); + return LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); if (vectorType.getElementType().isInteger(8) && vectorType.getNumElements() <= 8) - return rewriter.create( - loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input); + return LLVM::BitcastOp::create( + rewriter, loc, + rewriter.getIntegerType(vectorType.getNumElements() * 8), input); if (isa(vectorType.getElementType()) && vectorType.getElementTypeBitWidth() <= 8) { int64_t numWords = llvm::divideCeil( vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32); - return rewriter.create( - loc, VectorType::get(numWords, rewriter.getI32Type()), input); + return LLVM::BitcastOp::create( + rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), + input); } } return input; @@ -549,8 +552,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter, Type inputType = input.getType(); Type outputType = rewriter.getI32Type(); if (auto intType = dyn_cast(inputType)) - return rewriter.create(loc, outputType, input); - return rewriter.create(loc, outputType, input); + return LLVM::ZExtOp::create(rewriter, loc, outputType, input); + return LLVM::BitcastOp::create(rewriter, loc, outputType, input); } /// Push an input operand. If it is a float type, nothing to do. If it is @@ -576,8 +579,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, Type elemType = vectorType.getElementType(); if (elemType.isBF16()) - llvmInput = rewriter.create( - loc, vectorType.clone(rewriter.getI16Type()), llvmInput); + llvmInput = LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput); if (elemType.getIntOrFloatBitWidth() > 8) { operands.push_back(llvmInput); return; @@ -613,7 +616,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter, // (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments. // Add in the zeros here. if (numBits < 32) - castInput = rewriter.create(loc, i32, castInput); + castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput); operands.push_back(castInput); } @@ -633,8 +636,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter, auto vectorType = dyn_cast(inputType); Type elemType = vectorType.getElementType(); if (elemType.isBF16()) - output = rewriter.create( - loc, vectorType.clone(rewriter.getI16Type()), output); + output = LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), output); operands.push_back(output); if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) { operands.push_back(createI1Constant(rewriter, loc, subwordOffset)); @@ -992,7 +995,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern { }; Value lowered = rewriter.create(loweredOp)->getResult(0); if (outType != intrinsicOutType) - lowered = rewriter.create(loc, outType, lowered); + lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered); rewriter.replaceOp(op, lowered); return success(); } @@ -1092,8 +1095,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern { Operation *maybeCastBack = lowered; if (rawOutType != outType) - maybeCastBack = - rewriter.create(loc, outType, lowered->getResult(0)); + maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType, + lowered->getResult(0)); rewriter.replaceOp(op, maybeCastBack->getResults()); return success(); @@ -1143,22 +1146,22 @@ struct TransposeLoadOpLowering switch (elementTypeSize) { case 4: { assert(numElements == 16); - auto rocdlOp = - rewriter.create(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 6: { assert(numElements == 16); - auto rocdlOp = - rewriter.create(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } case 8: { assert(numElements == 8); - auto rocdlOp = - rewriter.create(loc, rocdlResultType, srcPtr); + auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc, + rocdlResultType, srcPtr); rewriter.replaceOpWithNewOp(op, llvmResultType, rocdlOp); break; } @@ -1316,21 +1319,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite( Type sourceElemType = getElementTypeOrSelf(op.getSource()); // Extend to a v4i8 if (!sourceVecType || sourceVecType.getNumElements() < 4) { - Value longVec = rewriter.create(loc, v4i8); + Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8); if (!sourceVecType) { - longVec = rewriter.create( - loc, longVec, source, createI32Constant(rewriter, loc, 0)); + longVec = LLVM::InsertElementOp::create( + rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); - Value elem = rewriter.create(loc, source, idx); + Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = - rewriter.create(loc, longVec, elem, idx); + LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } - Value i32Source = rewriter.create(loc, i32, source); + Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (resultVecType) { if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) { rewriter.replaceOpWithNewOp(op, f32, i32Source, @@ -1382,21 +1385,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite( // Extend to a packedVectorType if (sourceVecType.getNumElements() < packedVecType.getNumElements()) { - Value longVec = rewriter.create(loc, packedVecType); + Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType); if (!sourceVecType) { - longVec = rewriter.create( - loc, longVec, source, createI32Constant(rewriter, loc, 0)); + longVec = LLVM::InsertElementOp::create( + rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0)); } else { for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) { Value idx = createI32Constant(rewriter, loc, i); - Value elem = rewriter.create(loc, source, idx); + Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx); longVec = - rewriter.create(loc, longVec, elem, idx); + LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx); } } source = longVec; } - Value i32Source = rewriter.create(loc, i32, source); + Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source); if (isa(sourceElemType) && destElemType.isF32()) rewriter.replaceOpWithNewOp( @@ -1454,54 +1457,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite( Value scale = adaptor.getScale(); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create(loc, intResultType, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, intResultType, existing); else - existing = rewriter.create(loc, intResultType); + existing = LLVM::ZeroOp::create(rewriter, loc, intResultType); if (sourceVecType.getNumElements() < 2) { Value c0 = createI32Constant(rewriter, loc, 0); - Value elem0 = rewriter.create(loc, source, c0); + Value elem0 = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); VectorType v2 = VectorType::get(2, sourceElemType); - source = rewriter.create(loc, v2); - source = rewriter.create(loc, source, elem0, c0); + source = LLVM::ZeroOp::create(rewriter, loc, v2); + source = LLVM::InsertElementOp::create(rewriter, loc, source, elem0, c0); } Value sourceA, sourceB; if (sourceElemType.isF32()) { Value c0 = createI32Constant(rewriter, loc, 0); Value c1 = createI32Constant(rewriter, loc, 1); - sourceA = rewriter.create(loc, source, c0); - sourceB = rewriter.create(loc, source, c1); + sourceA = LLVM::ExtractElementOp::create(rewriter, loc, source, c0); + sourceB = LLVM::ExtractElementOp::create(rewriter, loc, source, c1); } Value result; if (sourceElemType.isF32() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkBf8Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp8Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isF32() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, sourceA, sourceB, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4F32Op::create(rewriter, loc, intResultType, + existing, sourceA, sourceB, + scale, op.getIndex()); else if (sourceElemType.isF16() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4F16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else if (sourceElemType.isBF16() && isa(resultElemType)) - result = rewriter.create( - loc, intResultType, existing, source, scale, op.getIndex()); + result = ROCDL::CvtScaleF32PkFp4Bf16Op::create( + rewriter, loc, intResultType, existing, source, scale, op.getIndex()); else return failure(); @@ -1526,20 +1532,20 @@ LogicalResult PackedTrunc2xFp8OpLowering::matchAndRewrite( Value sourceA = adaptor.getSourceA(); Value sourceB = adaptor.getSourceB(); if (!sourceB) - sourceB = rewriter.create(loc, sourceA.getType()); + sourceB = LLVM::UndefOp::create(rewriter, loc, sourceA.getType()); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create(loc, i32, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else - existing = rewriter.create(loc, i32); + existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create(loc, i32, sourceA, sourceB, - existing, op.getWordIndex()); + result = ROCDL::CvtPkBf8F32Op::create(rewriter, loc, i32, sourceA, sourceB, + existing, op.getWordIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create(loc, i32, sourceA, sourceB, - existing, op.getWordIndex()); + result = ROCDL::CvtPkFp8F32Op::create(rewriter, loc, i32, sourceA, sourceB, + existing, op.getWordIndex()); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), result); @@ -1563,17 +1569,17 @@ LogicalResult PackedStochRoundFp8OpLowering::matchAndRewrite( Value stoch = adaptor.getStochiasticParam(); Value existing = adaptor.getExisting(); if (existing) - existing = rewriter.create(loc, i32, existing); + existing = LLVM::BitcastOp::create(rewriter, loc, i32, existing); else - existing = rewriter.create(loc, i32); + existing = LLVM::UndefOp::create(rewriter, loc, i32); Value result; if (typeIsExpectedBf8ForChipset(chipset, resultElemType)) - result = rewriter.create( - loc, i32, source, stoch, existing, op.getStoreIndex()); + result = ROCDL::CvtSrBf8F32Op::create(rewriter, loc, i32, source, stoch, + existing, op.getStoreIndex()); else if (typeIsExpectedFp8ForChipset(chipset, resultElemType)) - result = rewriter.create( - loc, i32, source, stoch, existing, op.getStoreIndex()); + result = ROCDL::CvtSrFp8F32Op::create(rewriter, loc, i32, source, stoch, + existing, op.getStoreIndex()); result = rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(resultType), result); @@ -1617,14 +1623,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern { if (operandType.getIntOrFloatBitWidth() <= 16) { if (llvm::isa(operandType)) { operand = - rewriter.create(loc, llvmSrcIntType, operand); + LLVM::BitcastOp::create(rewriter, loc, llvmSrcIntType, operand); } auto llvmVecType = typeConverter->convertType(mlir::VectorType::get( 32 / operandType.getIntOrFloatBitWidth(), llvmSrcIntType)); - Value undefVec = rewriter.create(loc, llvmVecType); - operand = rewriter.create( - loc, undefVec, operand, createI32Constant(rewriter, loc, 0)); - operand = rewriter.create(loc, llvmType, operand); + Value undefVec = LLVM::UndefOp::create(rewriter, loc, llvmVecType); + operand = + LLVM::InsertElementOp::create(rewriter, loc, undefVec, operand, + createI32Constant(rewriter, loc, 0)); + operand = LLVM::BitcastOp::create(rewriter, loc, llvmType, operand); } return operand; }; @@ -1711,14 +1718,15 @@ struct AMDGPUDPPLowering : public ConvertOpToLLVMPattern { bool boundCtrl = DppOp->getAttrOfType("bound_ctrl").getValue(); // create a ROCDL_DPPMovOp instruction with the appropriate attributes - auto dppMovOp = rewriter.create( - loc, llvmType, old, src, DppCtrl, rowMask, bankMask, boundCtrl); + auto dppMovOp = + ROCDL::DPPUpdateOp::create(rewriter, loc, llvmType, old, src, DppCtrl, + rowMask, bankMask, boundCtrl); Value result = dppMovOp.getRes(); if (srcType.getIntOrFloatBitWidth() < 32) { - result = rewriter.create(loc, llvmSrcIntType, result); + result = LLVM::TruncOp::create(rewriter, loc, llvmSrcIntType, result); if (!llvm::isa(srcType)) { - result = rewriter.create(loc, srcType, result); + result = LLVM::BitcastOp::create(rewriter, loc, srcType, result); } } @@ -1752,7 +1760,7 @@ struct AMDGPUSwizzleBitModeLowering SmallVector swizzled; for (Value v : decomposed) { Value res = - rewriter.create(loc, v.getType(), v, maskValue); + ROCDL::DsSwizzleOp::create(rewriter, loc, v.getType(), v, maskValue); swizzled.emplace_back(res); } diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 3b143ca1ef9ce..3b148f9021666 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -50,9 +50,9 @@ static Value buildMinMaxReductionSeq(Location loc, Value value = *valueIt++; for (; valueIt != values.end(); ++valueIt) { if (predicate == arith::CmpIPredicate::sgt) - value = builder.create(loc, value, *valueIt); + value = arith::MaxSIOp::create(builder, loc, value, *valueIt); else - value = builder.create(loc, value, *valueIt); + value = arith::MinSIOp::create(builder, loc, value, *valueIt); } return value; @@ -154,9 +154,9 @@ class AffineForLowering : public OpRewritePattern { Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); Value step = - rewriter.create(loc, op.getStepAsInt()); - auto scfForOp = rewriter.create(loc, lowerBound, upperBound, - step, op.getInits()); + arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt()); + auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, + step, op.getInits()); rewriter.eraseBlock(scfForOp.getBody()); rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(), scfForOp.getRegion().end()); @@ -197,7 +197,7 @@ class AffineParallelLowering : public OpRewritePattern { } steps.reserve(op.getSteps().size()); for (int64_t step : op.getSteps()) - steps.push_back(rewriter.create(loc, step)); + steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step)); // Get the terminator op. auto affineParOpTerminator = @@ -205,9 +205,9 @@ class AffineParallelLowering : public OpRewritePattern { scf::ParallelOp parOp; if (op.getResults().empty()) { // Case with no reduction operations/return values. - parOp = rewriter.create(loc, lowerBoundTuple, - upperBoundTuple, steps, - /*bodyBuilderFn=*/nullptr); + parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple, + upperBoundTuple, steps, + /*bodyBuilderFn=*/nullptr); rewriter.eraseBlock(parOp.getBody()); rewriter.inlineRegionBefore(op.getRegion(), parOp.getRegion(), parOp.getRegion().end()); @@ -233,9 +233,9 @@ class AffineParallelLowering : public OpRewritePattern { identityVals.push_back( arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc)); } - parOp = rewriter.create( - loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, - /*bodyBuilderFn=*/nullptr); + parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple, + upperBoundTuple, steps, identityVals, + /*bodyBuilderFn=*/nullptr); // Copy the body of the affine.parallel op. rewriter.eraseBlock(parOp.getBody()); @@ -261,7 +261,7 @@ class AffineParallelLowering : public OpRewritePattern { Value reductionResult = arith::getReductionOp( reductionOpValue, rewriter, loc, reductionBody.getArgument(0), reductionBody.getArgument(1)); - rewriter.create(loc, reductionResult); + scf::ReduceReturnOp::create(rewriter, loc, reductionResult); } rewriter.replaceOp(op, parOp.getResults()); return success(); @@ -278,7 +278,7 @@ class AffineIfLowering : public OpRewritePattern { // Now we just have to handle the condition logic. auto integerSet = op.getIntegerSet(); - Value zeroConstant = rewriter.create(loc, 0); + Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector operands(op.getOperands()); auto operandsRef = llvm::ArrayRef(operands); @@ -298,18 +298,18 @@ class AffineIfLowering : public OpRewritePattern { auto pred = isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge; Value cmpVal = - rewriter.create(loc, pred, affResult, zeroConstant); - cond = cond - ? rewriter.create(loc, cond, cmpVal).getResult() - : cmpVal; + arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant); + cond = + cond ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult() + : cmpVal; } cond = cond ? cond - : rewriter.create(loc, /*value=*/1, - /*width=*/1); + : arith::ConstantIntOp::create(rewriter, loc, /*value=*/1, + /*width=*/1); bool hasElseRegion = !op.getElseRegion().empty(); - auto ifOp = rewriter.create(loc, op.getResultTypes(), cond, - hasElseRegion); + auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond, + hasElseRegion); rewriter.inlineRegionBefore(op.getThenRegion(), &ifOp.getThenRegion().back()); rewriter.eraseBlock(&ifOp.getThenRegion().back()); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 156c679c5039e..73a17b09721b2 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -115,9 +115,9 @@ static Value castF32To(Type desType, Value f32, Location loc, if (elementType.isF32()) return f32; if (elementType.getIntOrFloatBitWidth() < 32) - return rewriter.create(loc, desType, f32); + return arith::TruncFOp::create(rewriter, loc, desType, f32); if (elementType.getIntOrFloatBitWidth() > 32) - return rewriter.create(loc, desType, f32); + return arith::ExtFOp::create(rewriter, loc, desType, f32); llvm_unreachable("The only 32-bit float type is f32"); } @@ -139,27 +139,27 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, Type outElemType = getElementTypeOrSelf(op.getOut().getType()); VectorType extResType = VectorType::get(2, rewriter.getF32Type()); if (!inVecType) { - Value asFloat = rewriter.create( - loc, rewriter.getF32Type(), in, 0); + Value asFloat = amdgpu::ExtPackedFp8Op::create( + rewriter, loc, rewriter.getF32Type(), in, 0); Value result = castF32To(outElemType, asFloat, loc, rewriter); rewriter.replaceOp(op, result); return success(); } int64_t numElements = inVecType.getNumElements(); - Value zero = rewriter.create( - loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); + Value zero = arith::ConstantOp::create( + rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); VectorType outType = cast(op.getOut().getType()); if (inVecType.getShape().empty()) { Value zerodSplat = rewriter.createOrFold(loc, outType, zero); Value scalarIn = - rewriter.create(loc, in, ArrayRef{}); + vector::ExtractOp::create(rewriter, loc, in, ArrayRef{}); Value scalarExt = - rewriter.create(loc, outElemType, scalarIn); - Value result = rewriter.create(loc, scalarExt, zerodSplat, - ArrayRef{}); + arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn); + Value result = vector::InsertOp::create(rewriter, loc, scalarExt, + zerodSplat, ArrayRef{}); rewriter.replaceOp(op, result); return success(); } @@ -171,32 +171,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op, if (inVecType.getRank() > 1) { inVecType = VectorType::get(SmallVector{numElements}, inVecType.getElementType()); - in = rewriter.create(loc, inVecType, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in); } for (int64_t i = 0; i < numElements; i += 4) { int64_t elemsThisOp = std::min(numElements, i + 4) - i; - Value inSlice = rewriter.create( - loc, in, i, elemsThisOp, 1); + Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i, + elemsThisOp, 1); for (int64_t j = 0; j < elemsThisOp; j += 2) { if (i + j + 1 < numElements) { // Convert two 8-bit elements - Value asFloats = rewriter.create( - loc, extResType, inSlice, j / 2); + Value asFloats = amdgpu::ExtPackedFp8Op::create( + rewriter, loc, extResType, inSlice, j / 2); Type desType = VectorType::get(2, outElemType); Value asType = castF32To(desType, asFloats, loc, rewriter); - result = rewriter.create( - loc, asType, result, i + j, 1); + result = vector::InsertStridedSliceOp::create(rewriter, loc, asType, + result, i + j, 1); } else { // Convert a 8-bit element - Value asFloat = rewriter.create( - loc, rewriter.getF32Type(), inSlice, j / 2 * 2); + Value asFloat = amdgpu::ExtPackedFp8Op::create( + rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2); Value asType = castF32To(outElemType, asFloat, loc, rewriter); - result = rewriter.create(loc, asType, result, i + j); + result = vector::InsertOp::create(rewriter, loc, asType, result, i + j); } } } if (inVecType.getRank() != outType.getRank()) { - result = rewriter.create(loc, outType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outType, result); } rewriter.replaceOp(op, result); @@ -208,9 +208,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { if (type.isF32()) return value; if (type.getIntOrFloatBitWidth() < 32) - return rewriter.create(loc, rewriter.getF32Type(), value); + return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value); if (type.getIntOrFloatBitWidth() > 32) - return rewriter.create(loc, rewriter.getF32Type(), value); + return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value); llvm_unreachable("The only 32-bit float type is f32"); } @@ -250,13 +250,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc, loc, arith::CmpFPredicate::OEQ, source, negInf); Value isNan = rewriter.createOrFold( loc, arith::CmpFPredicate::UNO, source, source); - Value isNonFinite = rewriter.create( - loc, rewriter.create(loc, isInf, isNegInf), isNan); + Value isNonFinite = arith::OrIOp::create( + rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf), + isNan); - Value clampedBelow = rewriter.create(loc, source, minCst); - Value clamped = rewriter.create(loc, clampedBelow, maxCst); + Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst); + Value clamped = + arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst); Value res = - rewriter.create(loc, isNonFinite, source, clamped); + arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped); return res; } @@ -290,25 +292,25 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, VectorType truncResType = VectorType::get(4, outElemType); if (!inVectorTy) { Value asFloat = castToF32(in, loc, rewriter); - Value asF8s = rewriter.create( - loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, + Value asF8s = amdgpu::PackedTrunc2xFp8Op::create( + rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, /*existing=*/nullptr); - Value result = rewriter.create(loc, asF8s, 0); + Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0); rewriter.replaceOp(op, result); return success(); } int64_t numElements = outVecType.getNumElements(); - Value zero = rewriter.create( - loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); + Value zero = arith::ConstantOp::create( + rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); if (outVecType.getShape().empty()) { Value scalarIn = - rewriter.create(loc, in, ArrayRef{}); + vector::ExtractOp::create(rewriter, loc, in, ArrayRef{}); // Recurse to send the 0-D vector case to the 1-D vector case Value scalarTrunc = - rewriter.create(loc, outElemType, scalarIn); - Value result = rewriter.create(loc, scalarTrunc, zero, - ArrayRef{}); + arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn); + Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero, + ArrayRef{}); rewriter.replaceOp(op, result); return success(); } @@ -320,32 +322,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op, if (inVectorTy.getRank() > 1) { inVectorTy = VectorType::get(SmallVector{numElements}, inVectorTy.getElementType()); - in = rewriter.create(loc, inVectorTy, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in); } for (int64_t i = 0; i < numElements; i += 4) { int64_t elemsThisOp = std::min(numElements, i + 4) - i; Value thisResult = nullptr; for (int64_t j = 0; j < elemsThisOp; j += 2) { - Value elemA = rewriter.create(loc, in, i + j); + Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j); Value asFloatA = castToF32(elemA, loc, rewriter); Value asFloatB = nullptr; if (j + 1 < elemsThisOp) { - Value elemB = rewriter.create(loc, in, i + j + 1); + Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1); asFloatB = castToF32(elemB, loc, rewriter); } - thisResult = rewriter.create( - loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); + thisResult = amdgpu::PackedTrunc2xFp8Op::create( + rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); } if (elemsThisOp < 4) - thisResult = rewriter.create( - loc, thisResult, 0, elemsThisOp, 1); - result = rewriter.create(loc, thisResult, - result, i, 1); + thisResult = vector::ExtractStridedSliceOp::create( + rewriter, loc, thisResult, 0, elemsThisOp, 1); + result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult, + result, i, 1); } if (inVectorTy.getRank() != outVecType.getRank()) { - result = rewriter.create(loc, outVecType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result); } rewriter.replaceOp(op, result); @@ -373,10 +375,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( // Handle the case where input type is not a vector type if (!inVectorTy) { - auto sourceB = rewriter.create(loc, rewriter.getF32Type()); + auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type()); Value asF16s = - rewriter.create(loc, truncResType, in, sourceB); - Value result = rewriter.create(loc, asF16s, 0); + ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB); + Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0); rewriter.replaceOp(op, result); return success(); } @@ -389,7 +391,7 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( if (inVectorTy.getRank() > 1) { inVectorTy = VectorType::get(SmallVector{numElements}, inVectorTy.getElementType()); - in = rewriter.create(loc, inVectorTy, in); + in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in); } // Handle the vector case. We also handle the (uncommon) case where the vector @@ -397,25 +399,25 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite( for (int64_t i = 0; i < numElements; i += 2) { int64_t elemsThisOp = std::min(numElements, i + 2) - i; Value thisResult = nullptr; - Value elemA = rewriter.create(loc, in, i); - Value elemB = rewriter.create(loc, rewriter.getF32Type()); + Value elemA = vector::ExtractOp::create(rewriter, loc, in, i); + Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type()); if (elemsThisOp == 2) { - elemB = rewriter.create(loc, in, i + 1); + elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1); } thisResult = - rewriter.create(loc, truncResType, elemA, elemB); + ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB); // Place back the truncated result into the possibly larger vector. If we // are operating on a size 2 vector, these operations should be folded away - thisResult = rewriter.create( - loc, thisResult, 0, elemsThisOp, 1); - result = rewriter.create(loc, thisResult, - result, i, 1); + thisResult = vector::ExtractStridedSliceOp::create( + rewriter, loc, thisResult, 0, elemsThisOp, 1); + result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult, + result, i, 1); } if (inVectorTy.getRank() != outVecType.getRank()) { - result = rewriter.create(loc, outVecType, result); + result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result); } rewriter.replaceOp(op, result); @@ -472,18 +474,18 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32; if (scaleType.getIntOrFloatBitWidth() < 32) - scale = rewriter.create(loc, scaleF32Type, scale); + scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale); else if (scaleType.getIntOrFloatBitWidth() > 32) - scale = rewriter.create(loc, scaleF32Type, scale); + scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale); VectorType extScaleResultType = VectorType::get(opWidth, outType); if (!outVecType) { - Value inCast = rewriter.create( - loc, VectorType::get(1, inType), in); + Value inCast = vector::BroadcastOp::create(rewriter, loc, + VectorType::get(1, inType), in); // TODO: replace this with non-packed ScaledExtOp - Value scaleExt = rewriter.create( - loc, extScaleResultType, inCast, scale, 0); + Value scaleExt = amdgpu::ScaledExtPackedOp::create( + rewriter, loc, extScaleResultType, inCast, scale, 0); scaleExt = rewriter.replaceOpWithNewOp(op, scaleExt, 0); return success(); } @@ -508,20 +510,20 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, int64_t blockSize = computeProduct(ratio); - Value zero = rewriter.create( - loc, outType, rewriter.getFloatAttr(outType, 0.0)); + Value zero = arith::ConstantOp::create(rewriter, loc, outType, + rewriter.getFloatAttr(outType, 0.0)); Value result = rewriter.createOrFold(loc, outVecType, zero); for (SmallVector offsets : StaticTileOffsetRange(inShape, ratio)) { SmallVector strides(offsets.size(), 1); - Value block = rewriter.create( - loc, in, offsets, ratio, strides); + Value block = vector::ExtractStridedSliceOp::create( + rewriter, loc, in, offsets, ratio, strides); VectorType block1DType = VectorType::get(blockSize, inType); Value block1D = - rewriter.create(loc, block1DType, block); + vector::ShapeCastOp::create(rewriter, loc, block1DType, block); Value uniformScale = - rewriter.create(loc, scale, offsets); + vector::ExtractOp::create(rewriter, loc, scale, offsets); VectorType blockResultType = VectorType::get(blockSize, outType); Value blockResult = @@ -530,23 +532,23 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); i < blockSize; i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = rewriter.create( - loc, block1D, i, sliceWidth, 1); + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, sliceWidth, 1); // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 - Value scaleExt = rewriter.create( - loc, extScaleResultType, slice, uniformScale, 0); + Value scaleExt = amdgpu::ScaledExtPackedOp::create( + rewriter, loc, extScaleResultType, slice, uniformScale, 0); if (sliceWidth != opWidth) - scaleExt = rewriter.create( - loc, scaleExt, 0, sliceWidth, 1); - blockResult = rewriter.create( - loc, scaleExt, blockResult, i, 1); + scaleExt = vector::ExtractStridedSliceOp::create( + rewriter, loc, scaleExt, 0, sliceWidth, 1); + blockResult = vector::InsertStridedSliceOp::create( + rewriter, loc, scaleExt, blockResult, i, 1); } VectorType resultType = VectorType::get(ratio, outType); Value cast = - rewriter.create(loc, resultType, blockResult); - result = rewriter.create(loc, cast, result, - offsets, strides); + vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult); + result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result, + offsets, strides); } rewriter.replaceOp(op, result); @@ -578,21 +580,22 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, Type scaleF32Type = scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32; if (scaleType.getIntOrFloatBitWidth() < 32) - scale = rewriter.create(loc, scaleF32Type, scale); + scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale); else if (scaleType.getIntOrFloatBitWidth() > 32) - scale = rewriter.create(loc, scaleF32Type, scale); + scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale); - Value zero = rewriter.create( - loc, outType, rewriter.getFloatAttr(outType, 0.0)); + Value zero = arith::ConstantOp::create(rewriter, loc, outType, + rewriter.getFloatAttr(outType, 0.0)); unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth(); VectorType truncScaleResultType = VectorType::get(numPackedElem, outType); if (!outVecType) { Type inVecType = VectorType::get(1, inType); - Value inCast = rewriter.create(loc, inVecType, in); + Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in); // TODO: replace this with non-packed ScaledTruncOp - Value scaleTrunc = rewriter.create( - loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr); + Value scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, inCast, scale, 0, + /*existing=*/nullptr); scaleTrunc = rewriter.replaceOpWithNewOp(op, scaleTrunc, 0); return success(); @@ -623,13 +626,13 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, for (SmallVector offsets : StaticTileOffsetRange(inShape, ratio)) { SmallVector strides(offsets.size(), 1); - Value block = rewriter.create( - loc, in, offsets, ratio, strides); + Value block = vector::ExtractStridedSliceOp::create( + rewriter, loc, in, offsets, ratio, strides); VectorType block1DType = VectorType::get(blockSize, inType); Value block1D = - rewriter.create(loc, block1DType, block); + vector::ShapeCastOp::create(rewriter, loc, block1DType, block); Value uniformScale = - rewriter.create(loc, scale, offsets); + vector::ExtractOp::create(rewriter, loc, scale, offsets); VectorType blockResultType = VectorType::get(blockSize, outType); Value blockResult = @@ -638,26 +641,26 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); i < blockSize; i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = rewriter.create( - loc, block1D, i, sliceWidth, 1); + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, sliceWidth, 1); // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 - Value scaleTrunc = rewriter.create( - loc, truncScaleResultType, slice, uniformScale, 0, + Value scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, 0, /*existing=*/nullptr); int64_t packedWidth = cast(scaleTrunc.getType()).getNumElements(); if (packedWidth != opWidth) - scaleTrunc = rewriter.create( - loc, scaleTrunc, 0, sliceWidth, 1); - blockResult = rewriter.create( - loc, scaleTrunc, blockResult, i, 1); + scaleTrunc = vector::ExtractStridedSliceOp::create( + rewriter, loc, scaleTrunc, 0, sliceWidth, 1); + blockResult = vector::InsertStridedSliceOp::create( + rewriter, loc, scaleTrunc, blockResult, i, 1); } VectorType resultType = VectorType::get(ratio, outType); Value cast = - rewriter.create(loc, resultType, blockResult); - result = rewriter.create(loc, cast, result, - offsets, strides); + vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult); + result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result, + offsets, strides); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp index cbe0b3fda3410..ba489436a1a4d 100644 --- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp +++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp @@ -74,15 +74,15 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern { VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); auto denseAttr1D = DenseElementsAttr::get( tileSliceType, denseAttr.getSplatValue()); - auto constantOp1D = rewriter.create(loc, denseAttr1D); + auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D); - auto initTile = rewriter.create(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex, Value currentTile) { // Create 'arm_sme.insert_tile_slice' to write vector to tile // slice. - auto nextTile = b.create( - loc, tileType, constantOp1D, currentTile, tileSliceIndex); + auto nextTile = arm_sme::InsertTileSliceOp::create( + b, loc, tileType, constantOp1D, currentTile, tileSliceIndex); return nextTile.getResult(); }; auto forOp = mlir::arm_sme::createLoopOverTileSlices( diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index a5c08a6378021..59b3fe2e4eaed 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -110,9 +110,9 @@ class CmpFOpConversion : public OpConversionPattern { emitc::CmpPredicate predicate; switch (op.getPredicate()) { case arith::CmpFPredicate::AlwaysFalse: { - auto constant = rewriter.create( - op.getLoc(), rewriter.getI1Type(), - rewriter.getBoolAttr(/*value=*/false)); + auto constant = + emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/false)); rewriter.replaceOp(op, constant); return success(); } @@ -179,9 +179,9 @@ class CmpFOpConversion : public OpConversionPattern { return success(); } case arith::CmpFPredicate::AlwaysTrue: { - auto constant = rewriter.create( - op.getLoc(), rewriter.getI1Type(), - rewriter.getBoolAttr(/*value=*/true)); + auto constant = + emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(/*value=*/true)); rewriter.replaceOp(op, constant); return success(); } @@ -189,8 +189,8 @@ class CmpFOpConversion : public OpConversionPattern { // Compare the values naively auto cmpResult = - rewriter.create(op.getLoc(), op.getType(), predicate, - adaptor.getLhs(), adaptor.getRhs()); + emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate, + adaptor.getLhs(), adaptor.getRhs()); // Adjust the results for unordered/ordered semantics if (unordered) { @@ -213,16 +213,16 @@ class CmpFOpConversion : public OpConversionPattern { Value isNaN(ConversionPatternRewriter &rewriter, Location loc, Value operand) const { // A value is NaN exactly when it compares unequal to itself. - return rewriter.create( - loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand); + return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(), + emitc::CmpPredicate::ne, operand, operand); } /// Return a value that is true if \p operand is not NaN. Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc, Value operand) const { // A value is not NaN exactly when it compares equal to itself. - return rewriter.create( - loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand); + return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(), + emitc::CmpPredicate::eq, operand, operand); } /// Return a value that is true if the operands \p first and \p second are @@ -231,8 +231,8 @@ class CmpFOpConversion : public OpConversionPattern { Location loc, Value first, Value second) const { auto firstIsNaN = isNaN(rewriter, loc, first); auto secondIsNaN = isNaN(rewriter, loc, second); - return rewriter.create(loc, rewriter.getI1Type(), - firstIsNaN, secondIsNaN); + return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(), + firstIsNaN, secondIsNaN); } /// Return a value that is true if the operands \p first and \p second are @@ -241,8 +241,8 @@ class CmpFOpConversion : public OpConversionPattern { Value first, Value second) const { auto firstIsNotNaN = isNotNaN(rewriter, loc, first); auto secondIsNotNaN = isNotNaN(rewriter, loc, second); - return rewriter.create(loc, rewriter.getI1Type(), - firstIsNotNaN, secondIsNotNaN); + return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(), + firstIsNotNaN, secondIsNotNaN); } }; @@ -378,10 +378,10 @@ class CastConversion : public OpConversionPattern { Type attrType = (emitc::isPointerWideType(operandType)) ? rewriter.getIndexType() : operandType; - auto constOne = rewriter.create( - op.getLoc(), operandType, rewriter.getOneAttr(attrType)); - auto oneAndOperand = rewriter.create( - op.getLoc(), operandType, adaptor.getIn(), constOne); + auto constOne = emitc::ConstantOp::create( + rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType)); + auto oneAndOperand = emitc::BitwiseAndOp::create( + rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne); rewriter.replaceOpWithNewOp(op, opReturnType, oneAndOperand); return success(); @@ -466,9 +466,8 @@ class BinaryUIOpConversion final : public OpConversionPattern { Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType); Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType); - auto newDivOp = - rewriter.create(uiBinOp.getLoc(), unsignedType, - ArrayRef{lhsAdapted, rhsAdapted}); + auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType, + ArrayRef{lhsAdapted, rhsAdapted}); Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy); rewriter.replaceOp(uiBinOp, resultAdapted); return success(); @@ -588,38 +587,40 @@ class ShiftOpConversion : public OpConversionPattern { // Add a runtime check for overflow Value width; if (emitc::isPointerWideType(type)) { - Value eight = rewriter.create( - op.getLoc(), rhsType, rewriter.getIndexAttr(8)); - emitc::CallOpaqueOp sizeOfCall = rewriter.create( - op.getLoc(), rhsType, "sizeof", ArrayRef{eight}); - width = rewriter.create(op.getLoc(), rhsType, eight, - sizeOfCall.getResult(0)); + Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType, + rewriter.getIndexAttr(8)); + emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create( + rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef{eight}); + width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight, + sizeOfCall.getResult(0)); } else { - width = rewriter.create( - op.getLoc(), rhsType, + width = emitc::ConstantOp::create( + rewriter, op.getLoc(), rhsType, rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); } - Value excessCheck = rewriter.create( - op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); + Value excessCheck = + emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(), + emitc::CmpPredicate::lt, rhs, width); // Any concrete value is a valid refinement of poison. - Value poison = rewriter.create( - op.getLoc(), arithmeticType, + Value poison = emitc::ConstantOp::create( + rewriter, op.getLoc(), arithmeticType, (isa(arithmeticType) ? rewriter.getIntegerAttr(arithmeticType, 0) : rewriter.getIndexAttr(0))); - emitc::ExpressionOp ternary = rewriter.create( - op.getLoc(), arithmeticType, /*do_not_inline=*/false); + emitc::ExpressionOp ternary = emitc::ExpressionOp::create( + rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false); Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); auto currentPoint = rewriter.getInsertionPoint(); rewriter.setInsertionPointToStart(&bodyBlock); Value arithmeticResult = - rewriter.create(op.getLoc(), arithmeticType, lhs, rhs); - Value resultOrPoison = rewriter.create( - op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison); - rewriter.create(op.getLoc(), resultOrPoison); + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); + Value resultOrPoison = + emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType, + excessCheck, arithmeticResult, poison); + emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison); rewriter.setInsertionPoint(op->getBlock(), currentPoint); Value result = adaptValueType(ternary, rewriter, type); @@ -700,11 +701,12 @@ class FtoICastOpConversion : public OpConversionPattern { /*isSigned=*/false); } - Value result = rewriter.create( - castOp.getLoc(), actualResultType, adaptor.getOperands()); + Value result = emitc::CastOp::create( + rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands()); if (isa(castOp)) { - result = rewriter.create(castOp.getLoc(), dstType, result); + result = + emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result); } rewriter.replaceOp(castOp, result); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index f7bf581adc9e3..18e857c81af8d 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -293,11 +293,11 @@ LogicalResult IndexCastOpLowering::matchAndRewrite( [&](Type llvm1DVectorTy, ValueRange operands) -> Value { typename OpTy::Adaptor adaptor(operands); if (targetBits < sourceBits) { - return rewriter.create(op.getLoc(), llvm1DVectorTy, - adaptor.getIn()); + return LLVM::TruncOp::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getIn()); } - return rewriter.create(op.getLoc(), llvm1DVectorTy, - adaptor.getIn()); + return ExtCastTy::create(rewriter, op.getLoc(), llvm1DVectorTy, + adaptor.getIn()); }, rewriter); } @@ -324,12 +324,12 @@ LogicalResult AddUIExtendedOpLowering::matchAndRewrite( Type newOverflowType = typeConverter->convertType(overflowResultType); Type structType = LLVM::LLVMStructType::getLiteral(ctx, {sumResultType, newOverflowType}); - Value addOverflow = rewriter.create( - loc, structType, adaptor.getLhs(), adaptor.getRhs()); + Value addOverflow = LLVM::UAddWithOverflowOp::create( + rewriter, loc, structType, adaptor.getLhs(), adaptor.getRhs()); Value sumExtracted = - rewriter.create(loc, addOverflow, 0); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 0); Value overflowExtracted = - rewriter.create(loc, addOverflow, 1); + LLVM::ExtractValueOp::create(rewriter, loc, addOverflow, 1); rewriter.replaceOp(op, {sumExtracted, overflowExtracted}); return success(); } @@ -381,15 +381,15 @@ LogicalResult MulIExtendedOpLowering::matchAndRewrite( "LLVM dialect should support all signless integer types"); using LLVMExtOp = std::conditional_t; - Value lhsExt = rewriter.create(loc, wideType, adaptor.getLhs()); - Value rhsExt = rewriter.create(loc, wideType, adaptor.getRhs()); - Value mulExt = rewriter.create(loc, wideType, lhsExt, rhsExt); + Value lhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getLhs()); + Value rhsExt = LLVMExtOp::create(rewriter, loc, wideType, adaptor.getRhs()); + Value mulExt = LLVM::MulOp::create(rewriter, loc, wideType, lhsExt, rhsExt); // Split the 2*N-bit wide result into two N-bit values. - Value low = rewriter.create(loc, resultType, mulExt); - Value shiftVal = rewriter.create(loc, shiftValAttr); - Value highExt = rewriter.create(loc, mulExt, shiftVal); - Value high = rewriter.create(loc, resultType, highExt); + Value low = LLVM::TruncOp::create(rewriter, loc, resultType, mulExt); + Value shiftVal = LLVM::ConstantOp::create(rewriter, loc, shiftValAttr); + Value highExt = LLVM::LShrOp::create(rewriter, loc, mulExt, shiftVal); + Value high = LLVM::TruncOp::create(rewriter, loc, resultType, highExt); rewriter.replaceOp(op, {low, high}); return success(); @@ -435,8 +435,8 @@ CmpIOpLowering::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create( - op.getLoc(), llvm1DVectorTy, + return LLVM::ICmpOp::create( + rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs()); }, @@ -471,8 +471,8 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, op.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { OpAdaptor adaptor(operands); - return rewriter.create( - op.getLoc(), llvm1DVectorTy, + return LLVM::FCmpOp::create( + rewriter, op.getLoc(), llvm1DVectorTy, convertCmpPredicate(op.getPredicate()), adaptor.getLhs(), adaptor.getRhs(), fmf); }, diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index 434d7df853a5e..d43e6816641cb 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -117,12 +117,12 @@ static Value getScalarOrVectorConstInt(Type type, uint64_t value, if (auto vectorType = dyn_cast(type)) { Attribute element = IntegerAttr::get(vectorType.getElementType(), value); auto attr = SplatElementsAttr::get(vectorType, element); - return builder.create(loc, vectorType, attr); + return spirv::ConstantOp::create(builder, loc, vectorType, attr); } if (auto intType = dyn_cast(type)) - return builder.create( - loc, type, builder.getIntegerAttr(type, value)); + return spirv::ConstantOp::create(builder, loc, type, + builder.getIntegerAttr(type, value)); return nullptr; } @@ -418,18 +418,19 @@ static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Type type = lhs.getType(); // Calculate the remainder with spirv.UMod. - Value lhsAbs = builder.create(loc, type, lhs); - Value rhsAbs = builder.create(loc, type, rhs); - Value abs = builder.create(loc, lhsAbs, rhsAbs); + Value lhsAbs = SignedAbsOp::create(builder, loc, type, lhs); + Value rhsAbs = SignedAbsOp::create(builder, loc, type, rhs); + Value abs = spirv::UModOp::create(builder, loc, lhsAbs, rhsAbs); // Fix the sign. Value isPositive; if (lhs == signOperand) - isPositive = builder.create(loc, lhs, lhsAbs); + isPositive = spirv::IEqualOp::create(builder, loc, lhs, lhsAbs); else - isPositive = builder.create(loc, rhs, rhsAbs); - Value absNegate = builder.create(loc, type, abs); - return builder.create(loc, type, isPositive, abs, absNegate); + isPositive = spirv::IEqualOp::create(builder, loc, rhs, rhsAbs); + Value absNegate = spirv::SNegateOp::create(builder, loc, type, abs); + return spirv::SelectOp::create(builder, loc, type, isPositive, abs, + absNegate); } /// Converts arith.remsi to GLSL SPIR-V ops. @@ -601,13 +602,13 @@ struct ExtSII1Pattern final : public OpConversionPattern { Value allOnes; if (auto intTy = dyn_cast(dstType)) { unsigned componentBitwidth = intTy.getWidth(); - allOnes = rewriter.create( - loc, intTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, intTy, rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); } else if (auto vectorTy = dyn_cast(dstType)) { unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); - allOnes = rewriter.create( - loc, vectorTy, + allOnes = spirv::ConstantOp::create( + rewriter, loc, vectorTy, SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); } else { @@ -653,8 +654,8 @@ struct ExtSIPattern final : public OpConversionPattern { // First shift left to sequeeze out all leading bits beyond the original // bitwidth. Here we need to use the original source and result type's // bitwidth. - auto shiftLOp = rewriter.create( - op.getLoc(), dstType, adaptor.getIn(), shiftSize); + auto shiftLOp = spirv::ShiftLeftLogicalOp::create( + rewriter, op.getLoc(), dstType, adaptor.getIn(), shiftSize); // Then we perform arithmetic right shift to make sure we have the right // sign bits for negative values. @@ -757,9 +758,9 @@ struct TruncII1Pattern final : public OpConversionPattern { auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = rewriter.create( - loc, srcType, adaptor.getOperands()[0], mask); - Value isOne = rewriter.create(loc, maskedSrc, mask); + Value maskedSrc = spirv::BitwiseAndOp::create( + rewriter, loc, srcType, adaptor.getOperands()[0], mask); + Value isOne = spirv::IEqualOp::create(rewriter, loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); @@ -914,9 +915,9 @@ class CmpIOpBooleanPattern final : public OpConversionPattern { if (auto vectorType = dyn_cast(dstType)) type = VectorType::get(vectorType.getShape(), type); Value extLhs = - rewriter.create(op.getLoc(), type, adaptor.getLhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getLhs()); Value extRhs = - rewriter.create(op.getLoc(), type, adaptor.getRhs()); + arith::ExtUIOp::create(rewriter, op.getLoc(), type, adaptor.getRhs()); rewriter.replaceOpWithNewOp(op, op.getPredicate(), extLhs, extRhs); @@ -1067,12 +1068,12 @@ class CmpFOpNanNonePattern final : public OpConversionPattern { replace = spirv::ConstantOp::getZero(op.getType(), loc, rewriter); } } else { - Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - replace = rewriter.create(loc, lhsIsNan, rhsIsNan); + replace = spirv::LogicalOrOp::create(rewriter, loc, lhsIsNan, rhsIsNan); if (op.getPredicate() == arith::CmpFPredicate::ORD) - replace = rewriter.create(loc, replace); + replace = spirv::LogicalNotOp::create(rewriter, loc, replace); } rewriter.replaceOp(op, replace); @@ -1094,17 +1095,17 @@ class AddUIExtendedOpPattern final ConversionPatternRewriter &rewriter) const override { Type dstElemTy = adaptor.getLhs().getType(); Location loc = op->getLoc(); - Value result = rewriter.create(loc, adaptor.getLhs(), - adaptor.getRhs()); + Value result = spirv::IAddCarryOp::create(rewriter, loc, adaptor.getLhs(), + adaptor.getRhs()); - Value sumResult = rewriter.create( - loc, result, llvm::ArrayRef(0)); - Value carryValue = rewriter.create( - loc, result, llvm::ArrayRef(1)); + Value sumResult = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value carryValue = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); // Convert the carry value to boolean. Value one = spirv::ConstantOp::getOne(dstElemTy, loc, rewriter); - Value carryResult = rewriter.create(loc, carryValue, one); + Value carryResult = spirv::IEqualOp::create(rewriter, loc, carryValue, one); rewriter.replaceOp(op, {sumResult, carryResult}); return success(); @@ -1125,12 +1126,12 @@ class MulIExtendedOpPattern final : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Value result = - rewriter.create(loc, adaptor.getLhs(), adaptor.getRhs()); + SPIRVMulOp::create(rewriter, loc, adaptor.getLhs(), adaptor.getRhs()); - Value low = rewriter.create(loc, result, - llvm::ArrayRef(0)); - Value high = rewriter.create(loc, result, - llvm::ArrayRef(1)); + Value low = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(0)); + Value high = spirv::CompositeExtractOp::create(rewriter, loc, result, + llvm::ArrayRef(1)); rewriter.replaceOp(op, {low, high}); return success(); @@ -1183,20 +1184,20 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern { Location loc = op.getLoc(); Value spirvOp = - rewriter.create(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { rewriter.replaceOp(op, spirvOp); return success(); } - Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create(loc, dstType, lhsIsNan, - adaptor.getLhs(), spirvOp); - Value select2 = rewriter.create(loc, dstType, rhsIsNan, - adaptor.getRhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getLhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getRhs(), select1); rewriter.replaceOp(op, select2); return success(); @@ -1237,7 +1238,7 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern { Location loc = op.getLoc(); Value spirvOp = - rewriter.create(loc, dstType, adaptor.getOperands()); + SPIRVOp::create(rewriter, loc, dstType, adaptor.getOperands()); if (!shouldInsertNanGuards() || bitEnumContainsAll(op.getFastmath(), arith::FastMathFlags::nnan)) { @@ -1245,13 +1246,13 @@ class MinNumMaxNumFOpPattern final : public OpConversionPattern { return success(); } - Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); - Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + Value lhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getLhs()); + Value rhsIsNan = spirv::IsNanOp::create(rewriter, loc, adaptor.getRhs()); - Value select1 = rewriter.create(loc, dstType, lhsIsNan, - adaptor.getRhs(), spirvOp); - Value select2 = rewriter.create(loc, dstType, rhsIsNan, - adaptor.getLhs(), select1); + Value select1 = spirv::SelectOp::create(rewriter, loc, dstType, lhsIsNan, + adaptor.getRhs(), spirvOp); + Value select2 = spirv::SelectOp::create(rewriter, loc, dstType, rhsIsNan, + adaptor.getLhs(), select1); rewriter.replaceOp(op, select2); return success(); diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index 9c6de938a7108..1510b0b16b07d 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -41,11 +41,11 @@ class Sdot2dLoweringPattern : public OpRewritePattern { Value c2d = op.getC(); Location loc = op.getLoc(); Value b1d = - rewriter.create(loc, flattenedVectorType, b2d); + vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, b2d); Value c1d = - rewriter.create(loc, flattenedVectorType, c2d); - Value newOp = rewriter.create(loc, op.getRes().getType(), op.getA(), - b1d, c1d); + vector::ShapeCastOp::create(rewriter, loc, flattenedVectorType, c2d); + Value newOp = SdotOp::create(rewriter, loc, op.getRes().getType(), + op.getA(), b1d, c1d); rewriter.replaceOp(op, {newOp}); return success(); } diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 21ea444e31821..9bc3fa3473398 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -45,38 +45,38 @@ static Operation *createLoadTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1b_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1h_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1w_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1d_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1q_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1b_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1h_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1w_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1d_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_ld1q_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); break; } } @@ -91,38 +91,38 @@ static Operation *createStoreTileSliceIntrinsic( if (layout == arm_sme::TileSliceLayout::Horizontal) { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1b_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1h_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1w_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1d_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1q_horiz::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } else { switch (type) { case arm_sme::ArmSMETileType::ZAB: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1b_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAH: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1h_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAS: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1w_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAD: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1d_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); case arm_sme::ArmSMETileType::ZAQ: - return rewriter.create( - loc, maskOp, ptr, tileId, tileSliceI32); + return arm_sme::aarch64_sme_st1q_vert::create(rewriter, loc, maskOp, ptr, + tileId, tileSliceI32); } } llvm_unreachable("unknown type in createStoreTileSliceIntrinsic"); @@ -146,16 +146,16 @@ createAllocaForTile(RewriterBase &rewriter, Location loc, // Move to the first operation in the function. rewriter.setInsertionPointToStart(&func.getBlocks().front()); // Create an alloca matching the tile size of the `tileOp`. - auto vscale = rewriter.create(loc); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); auto tileElementType = tileOp.getTileType().getElementType(); auto memrefType = MemRefType::get( {ShapedType::kDynamic, ShapedType::kDynamic}, tileElementType); unsigned minElements = arm_sme::getSMETileSliceMinNumElts(tileElementType); auto minElementsOp = - rewriter.create(loc, minElements); - auto vectorLen = rewriter.create(loc, vscale, minElementsOp); - auto alloca = rewriter.create( - loc, memrefType, ValueRange{vectorLen, vectorLen}); + arith::ConstantIndexOp::create(rewriter, loc, minElements); + auto vectorLen = arith::MulIOp::create(rewriter, loc, vscale, minElementsOp); + auto alloca = memref::AllocaOp::create(rewriter, loc, memrefType, + ValueRange{vectorLen, vectorLen}); return alloca; } @@ -293,10 +293,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { Value tileMemory, Value sliceIndex) const { auto llvmType = getTypeConverter()->convertType(tileMemory.getType()); auto descriptor = - rewriter.create(loc, llvmType, tileMemory); - auto zero = rewriter.create(loc, 0, /*width=*/64); - auto sliceIndexI64 = rewriter.create( - loc, rewriter.getI64Type(), sliceIndex); + UnrealizedConversionCastOp::create(rewriter, loc, llvmType, tileMemory); + auto zero = arith::ConstantIntOp::create(rewriter, loc, 0, /*width=*/64); + auto sliceIndexI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), sliceIndex); return getStridedElementPtr( static_cast(rewriter), loc, llvm::cast(tileMemory.getType()), descriptor.getResult(0), @@ -309,28 +309,29 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { arm_sme::ArmSMETileType tileType, VectorType sliceType, IntegerAttr tileId, Value sliceIndex) const { // Cast the slice index to an i32. - auto sliceIndexI32 = rewriter.create( - loc, rewriter.getI32Type(), sliceIndex); + auto sliceIndexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create an all-true predicate for the slice. auto predicateType = sliceType.clone(rewriter.getI1Type()); - auto allTruePredicate = rewriter.create( - loc, DenseElementsAttr::get(predicateType, true)); + auto allTruePredicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Create padding vector (never used due to all-true predicate). - auto padVector = rewriter.create(loc, sliceType); + auto padVector = LLVM::PoisonOp::create(rewriter, loc, sliceType); // Get a pointer to the current slice. auto slicePtr = getInMemoryTileSlicePtr(rewriter, loc, tileAlloca, sliceIndex); // Read the value of the current slice from ZA. - auto currentTileSlice = rewriter.create( - loc, sliceType, padVector, allTruePredicate, tileId, sliceIndexI32); + auto currentTileSlice = arm_sme::aarch64_sme_read_horiz::create( + rewriter, loc, sliceType, padVector, allTruePredicate, tileId, + sliceIndexI32); // Load the new tile slice back from memory into ZA. createLoadTileSliceIntrinsic( rewriter, loc, tileType, arm_sme::TileSliceLayout::Horizontal, allTruePredicate, slicePtr, tileId, sliceIndexI32); // Store the current tile slice to memory. - auto zero = rewriter.create(loc, 0); - rewriter.create(loc, currentTileSlice, tileAlloca, - ValueRange{sliceIndex, zero}); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + vector::StoreOp::create(rewriter, loc, currentTileSlice, tileAlloca, + ValueRange{sliceIndex, zero}); } /// Emits a full in-place swap of the contents of a tile in ZA and a @@ -341,12 +342,14 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern { RewriterBase::InsertionGuard guard(rewriter); // Create an scf.for over all tile slices. auto minNumElts = - rewriter.create(loc, sliceType.getDimSize(0)); - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create( - loc, minNumElts, rewriter.create(loc)); - auto step = rewriter.create(loc, 1); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + arith::ConstantIndexOp::create(rewriter, loc, sliceType.getDimSize(0)); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto upperBound = + arith::MulIOp::create(rewriter, loc, minNumElts, + vector::VectorScaleOp::create(rewriter, loc)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); // Emit a swap for each tile slice. rewriter.setInsertionPointToStart(forOp.getBody()); auto sliceIndex = forOp.getInductionVar(); @@ -479,8 +482,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern { // // This holds for all tile sizes. int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt()); - rewriter.create( - loc, rewriter.getI32IntegerAttr(zeroMask)); + arm_sme::aarch64_sme_zero::create(rewriter, loc, + rewriter.getI32IntegerAttr(zeroMask)); // Create a placeholder op to preserve dataflow. // Note: Place the `get_tile` op at the start of the block. This ensures @@ -513,8 +516,8 @@ struct LoadTileSliceConversion auto tileSlice = loadTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. auto maskOp = loadTileSliceOp.getMask(); @@ -559,8 +562,8 @@ struct StoreTileSliceConversion auto tileSlice = storeTileSliceOp.getTileSliceIndex(); // Cast tile slice to i32 for intrinsic. - auto tileSliceI32 = rewriter.create( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); auto maskOp = storeTileSliceOp.getMask(); @@ -595,28 +598,28 @@ struct InsertTileSliceConversion auto tileSlice = insertTileSliceOp.getTileSliceIndex(); // Cast tile slice from index to i32 for intrinsic. - auto tileSliceI32 = rewriter.create( - loc, rewriter.getI32Type(), tileSlice); + auto tileSliceI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), tileSlice); // Create all active predicate mask. - auto one = rewriter.create( - loc, rewriter.getI1Type(), + auto one = arith::ConstantOp::create( + rewriter, loc, rewriter.getI1Type(), rewriter.getIntegerAttr(rewriter.getI1Type(), 1)); auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(), /*scalableDims=*/{true}); - auto allActiveMask = rewriter.create(loc, predTy, one); + auto allActiveMask = vector::SplatOp::create(rewriter, loc, predTy, one); // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice. switch (insertTileSliceOp.getLayout()) { case arm_sme::TileSliceLayout::Horizontal: - rewriter.create( - loc, tileId, tileSliceI32, allActiveMask, - insertTileSliceOp.getVector()); + arm_sme::aarch64_sme_write_horiz::create(rewriter, loc, tileId, + tileSliceI32, allActiveMask, + insertTileSliceOp.getVector()); break; case arm_sme::TileSliceLayout::Vertical: - rewriter.create( - loc, tileId, tileSliceI32, allActiveMask, - insertTileSliceOp.getVector()); + arm_sme::aarch64_sme_write_vert::create(rewriter, loc, tileId, + tileSliceI32, allActiveMask, + insertTileSliceOp.getVector()); break; } @@ -646,16 +649,16 @@ struct ExtractTileSliceConversion // Create an 'all true' predicate for the tile slice. auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type()); - auto allTruePredicate = rewriter.create( - loc, DenseElementsAttr::get(predicateType, true)); + auto allTruePredicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); // Zero destination/fallback for tile slice extraction. - auto zeroVector = rewriter.create( - loc, sliceType, rewriter.getZeroAttr(sliceType)); + auto zeroVector = arith::ConstantOp::create( + rewriter, loc, sliceType, rewriter.getZeroAttr(sliceType)); // Cast tile slice from index to i32 for intrinsic. - auto sliceIndexI32 = rewriter.create( - loc, rewriter.getI32Type(), sliceIndex); + auto sliceIndexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), sliceIndex); // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice. switch (extractTileSlice.getLayout()) { @@ -743,7 +746,7 @@ struct OuterProductOpConversion Value acc = outerProductOp.getAcc(); if (!acc) { // Initalize accumulator with zero. - auto zero = rewriter.create(loc, resultVectorType); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, resultVectorType); zero.setTileId(tileId); acc = zero; } @@ -754,16 +757,16 @@ struct OuterProductOpConversion if (!lhsMask || !rhsMask) { auto predTy = outerProductOp.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create( - loc, DenseElementsAttr::get(predTy, true)); + Value allActiveMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } // Create 'arm_sme.intr.mopa' outer product intrinsic. - rewriter.create(loc, tileId, lhsMask, rhsMask, - outerProductOp.getLhs(), - outerProductOp.getRhs()); + arm_sme::aarch64_sme_mopa::create(rewriter, loc, tileId, lhsMask, rhsMask, + outerProductOp.getLhs(), + outerProductOp.getRhs()); // The outerproduct intrinsics have no result, replace // 'arm_sme.outerproduct' with the input tile to preserve dataflow. @@ -792,7 +795,7 @@ struct OuterProductWideningOpConversion Value acc = op.getAcc(); if (!acc) { // Initalize accumulator with zero. - auto zero = rewriter.create(loc, op.getResultType()); + auto zero = arm_sme::ZeroOp::create(rewriter, loc, op.getResultType()); zero.setTileId(tileId); acc = zero; } @@ -801,14 +804,14 @@ struct OuterProductWideningOpConversion Value rhsMask = op.getRhsMask(); if (!lhsMask || !rhsMask) { auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type()); - Value allActiveMask = rewriter.create( - loc, DenseElementsAttr::get(predTy, true)); + Value allActiveMask = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predTy, true)); lhsMask = allActiveMask; rhsMask = allActiveMask; } - rewriter.create( - loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs()); + OuterProductWideningIntrOp::create(rewriter, loc, tileId, lhsMask, rhsMask, + adaptor.getLhs(), adaptor.getRhs()); // The outerproduct intrinsics have no result, replace // 'arm_sme.outerproduct' with the input tile to preserve dataflow. @@ -843,13 +846,13 @@ struct StreamingVLOpConversion auto *intrOp = [&]() -> Operation * { switch (streamingVlOp.getTypeSize()) { case arm_sme::TypeSize::Byte: - return rewriter.create(loc, i64Type); + return arm_sme::aarch64_sme_cntsb::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Half: - return rewriter.create(loc, i64Type); + return arm_sme::aarch64_sme_cntsh::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Word: - return rewriter.create(loc, i64Type); + return arm_sme::aarch64_sme_cntsw::create(rewriter, loc, i64Type); case arm_sme::TypeSize::Double: - return rewriter.create(loc, i64Type); + return arm_sme::aarch64_sme_cntsd::create(rewriter, loc, i64Type); } llvm_unreachable("unknown type size in StreamingVLOpConversion"); }(); @@ -872,8 +875,8 @@ static void mergeConsecutiveTileZerosInBlock(Block *block) { if (zeroOpsToMerge.size() <= 1) return; IRRewriter rewriter(zeroOpsToMerge.front()); - rewriter.create( - zeroOpsToMerge.front().getLoc(), + arm_sme::aarch64_sme_zero::create( + rewriter, zeroOpsToMerge.front().getLoc(), rewriter.getI32IntegerAttr(mergedZeroMask)); for (auto zeroOp : zeroOpsToMerge) rewriter.eraseOp(zeroOp); diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 458628c29c6ac..9a37b30c14813 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -39,7 +39,7 @@ SmallVector getMemrefIndices(ValueRange indices, unsigned rank, auto tileSliceOffset = tileSliceIndex; auto baseIndexPlusTileSliceOffset = - rewriter.create(loc, indices[0], tileSliceOffset); + arith::AddIOp::create(rewriter, loc, indices[0], tileSliceOffset); outIndices.push_back(baseIndexPlusTileSliceOffset); outIndices.push_back(indices[1]); @@ -59,10 +59,11 @@ FailureOr createLoadStoreForOverTileSlices( if (memrefIndices.size() != 2) return rewriter.notifyMatchFailure(loc, "invalid number of indices"); - auto minTileSlices = rewriter.create( - loc, arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, + arm_sme::getSMETileSliceMinNumElts(tileType.getElementType())); auto vscale = - rewriter.create(loc, rewriter.getIndexType()); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); @@ -70,7 +71,7 @@ FailureOr createLoadStoreForOverTileSlices( // elements in a vector of SVL bits for a given element type (SVL_B, // SVL_H, ..., SVL_Q). auto numTileSlices = - rewriter.create(loc, minTileSlices, vscale); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); Value predicate; Value upperBound; @@ -82,30 +83,30 @@ FailureOr createLoadStoreForOverTileSlices( // The upper bound of the loop must be clamped at `numTileSlices` as // `vector.create_mask` allows operands to be greater than the size of a // dimension. - auto numRowI64 = rewriter.create( - loc, rewriter.getI64Type(), maskDim0); - auto numTileSlicesI64 = rewriter.create( - loc, rewriter.getI64Type(), numTileSlices); + auto numRowI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), maskDim0); + auto numTileSlicesI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), numTileSlices); auto upperBoundI64 = - rewriter.create(loc, numRowI64, numTileSlicesI64); - upperBound = rewriter.create( - loc, rewriter.getIndexType(), upperBoundI64); + arith::MinSIOp::create(rewriter, loc, numRowI64, numTileSlicesI64); + upperBound = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), upperBoundI64); predicate = - rewriter.create(loc, predicateType, maskDim1); + vector::CreateMaskOp::create(rewriter, loc, predicateType, maskDim1); } else { upperBound = numTileSlices; // No mask. Create an 'all true' predicate for the tile slice. - predicate = rewriter.create( - loc, DenseElementsAttr::get(predicateType, true)); + predicate = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(predicateType, true)); } bool hasCarriedArgs = bool(initTile); - auto lowerBound = rewriter.create(loc, 0); - auto step = rewriter.create(loc, 1); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step, - hasCarriedArgs ? ValueRange{initTile} - : ValueRange{}); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto forOp = + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, + hasCarriedArgs ? ValueRange{initTile} : ValueRange{}); rewriter.setInsertionPointToStart(forOp.getBody()); Value tileSliceIndex = forOp.getInductionVar(); @@ -118,7 +119,7 @@ FailureOr createLoadStoreForOverTileSlices( assert(bool(nextTile) == hasCarriedArgs); if (nextTile) - rewriter.create(loc, nextTile); + scf::YieldOp::create(rewriter, loc, nextTile); return forOp; } @@ -194,9 +195,9 @@ struct TileLoadOpConversion : public OpRewritePattern { // Initialize tile with zero to satisfy padding. Inactive cols will be // zeroed anyway since the loads use zeroing predication. For inactive // rows however, no load will occur so these need to be zeroed. - initTile = rewriter.create(loc, tileType); + initTile = arm_sme::ZeroOp::create(rewriter, loc, tileType); } else { - initTile = rewriter.create(loc, tileType); + initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); } // Create a loop to load the active tile slices from memory. @@ -207,9 +208,10 @@ struct TileLoadOpConversion : public OpRewritePattern { Value currentTile) -> Value { // Create 'arm_sme.load_tile_slice' to load tile slice from memory // into tile. - return rewriter.create( - loc, tileType, tileLoadOp.getBase(), predicate, currentTile, - memrefIndices, tileSliceIndex, tileLoadOp.getLayout()); + return arm_sme::LoadTileSliceOp::create( + rewriter, loc, tileType, tileLoadOp.getBase(), predicate, + currentTile, memrefIndices, tileSliceIndex, + tileLoadOp.getLayout()); }); if (failed(forOp)) @@ -283,22 +285,22 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto numRows = createMaskOp.getOperands()[0]; auto numCols = createMaskOp.getOperands()[1]; - auto numColsI32 = rewriter.create( - loc, rewriter.getI32Type(), numCols); + auto numColsI32 = arith::IndexCastUIOp::create( + rewriter, loc, rewriter.getI32Type(), numCols); - auto initTile = rewriter.create(loc, tileType); + auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType); // Create a loop that loads each ZA tile slice from memory. - auto step = rewriter.create(loc, 1); - auto minTileSlices = rewriter.create( - loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, arm_sme::getSMETileSliceMinNumElts(tileElementType)); auto vscale = - rewriter.create(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create(loc, 0); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto numTileSlices = - rewriter.create(loc, minTileSlices, vscale); - auto forOp = rewriter.create(loc, lowerBound, numTileSlices, - step, ValueRange{initTile}); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices, + step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); @@ -306,17 +308,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto currentTile = forOp.getRegionIterArg(0); // Combine masks. - auto rowIsActive = rewriter.create( - loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); - auto rowIsActiveI32 = rewriter.create( - loc, rewriter.getI32Type(), rowIsActive); - auto mask = rewriter.create(loc, rowIsActiveI32, numColsI32); - auto maskIndex = - rewriter.create(loc, rewriter.getIndexType(), mask); + auto rowIsActive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows); + auto rowIsActiveI32 = arith::ExtSIOp::create( + rewriter, loc, rewriter.getI32Type(), rowIsActive); + auto mask = + arith::AndIOp::create(rewriter, loc, rowIsActiveI32, numColsI32); + auto maskIndex = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), mask); auto predicateType = VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true); - auto maskOp1D = rewriter.create( - loc, predicateType, maskIndex.getResult()); + auto maskOp1D = vector::CreateMaskOp::create(rewriter, loc, predicateType, + maskIndex.getResult()); auto memrefIndices = getMemrefIndices( tileLoadOp.getIndices(), tileLoadOp.getMemRefType().getRank(), @@ -324,17 +327,18 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion // Splat pad into 1-D vector matching type of tile slice. VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0); - auto pad1DOp = rewriter.create(loc, tileSliceType, padOp); + auto pad1DOp = vector::SplatOp::create(rewriter, loc, tileSliceType, padOp); - auto loadSlice = rewriter.create( - loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D, - /*passthru=*/pad1DOp); + auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType, + tileLoadOp.getBase(), + memrefIndices, maskOp1D, + /*passthru=*/pad1DOp); // Create 'arm_sme.insert_tile_slice' to insert slice into tile. - auto insertSlice = rewriter.create( - loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex, - tileLoadOp.getLayout()); - rewriter.create(loc, insertSlice.getResult()); + auto insertSlice = arm_sme::InsertTileSliceOp::create( + rewriter, loc, tileType, loadSlice->getResult(0), currentTile, + tileSliceIndex, tileLoadOp.getLayout()); + scf::YieldOp::create(rewriter, loc, insertSlice.getResult()); rewriter.setInsertionPointAfter(forOp); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 94f7caa315cf7..79e1683b4e2cf 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -203,7 +203,7 @@ static void addAsyncRuntimeApiDeclarations(ModuleOp module) { auto addFuncDecl = [&](StringRef name, FunctionType type) { if (module.lookupSymbol(name)) return; - builder.create(name, type).setPrivate(); + func::FuncOp::create(builder, name, type).setPrivate(); }; MLIRContext *ctx = module.getContext(); @@ -254,15 +254,15 @@ static void addResumeFunction(ModuleOp module) { auto voidTy = LLVM::LLVMVoidType::get(ctx); Type ptrType = AsyncAPI::opaquePointerType(ctx); - auto resumeOp = moduleBuilder.create( - kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); + auto resumeOp = LLVM::LLVMFuncOp::create( + moduleBuilder, kResume, LLVM::LLVMFunctionType::get(voidTy, {ptrType})); resumeOp.setPrivate(); auto *block = resumeOp.addEntryBlock(moduleBuilder); auto blockBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, block); - blockBuilder.create(resumeOp.getArgument(0)); - blockBuilder.create(ValueRange()); + LLVM::CoroResumeOp::create(blockBuilder, resumeOp.getArgument(0)); + LLVM::ReturnOp::create(blockBuilder, ValueRange()); } //===----------------------------------------------------------------------===// @@ -282,7 +282,8 @@ class AsyncRuntimeTypeConverter : public TypeConverter { // in patterns for other dialects. auto addUnrealizedCast = [](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - auto cast = builder.create(loc, type, inputs); + auto cast = + UnrealizedConversionCastOp::create(builder, loc, type, inputs); return cast.getResult(0); }; @@ -343,8 +344,8 @@ class CoroIdOpConversion : public AsyncOpConversionPattern { // Constants for initializing coroutine frame. auto constZero = - rewriter.create(loc, rewriter.getI32Type(), 0); - auto nullPtr = rewriter.create(loc, ptrType); + LLVM::ConstantOp::create(rewriter, loc, rewriter.getI32Type(), 0); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, ptrType); // Get coroutine id: @llvm.coro.id. rewriter.replaceOpWithNewOp( @@ -372,33 +373,33 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern { // Get coroutine frame size: @llvm.coro.size.i64. Value coroSize = - rewriter.create(loc, rewriter.getI64Type()); + LLVM::CoroSizeOp::create(rewriter, loc, rewriter.getI64Type()); // Get coroutine frame alignment: @llvm.coro.align.i64. Value coroAlign = - rewriter.create(loc, rewriter.getI64Type()); + LLVM::CoroAlignOp::create(rewriter, loc, rewriter.getI64Type()); // Round up the size to be multiple of the alignment. Since aligned_alloc // requires the size parameter be an integral multiple of the alignment // parameter. auto makeConstant = [&](uint64_t c) { - return rewriter.create(op->getLoc(), - rewriter.getI64Type(), c); + return LLVM::ConstantOp::create(rewriter, op->getLoc(), + rewriter.getI64Type(), c); }; - coroSize = rewriter.create(op->getLoc(), coroSize, coroAlign); + coroSize = LLVM::AddOp::create(rewriter, op->getLoc(), coroSize, coroAlign); coroSize = - rewriter.create(op->getLoc(), coroSize, makeConstant(1)); + LLVM::SubOp::create(rewriter, op->getLoc(), coroSize, makeConstant(1)); Value negCoroAlign = - rewriter.create(op->getLoc(), makeConstant(0), coroAlign); + LLVM::SubOp::create(rewriter, op->getLoc(), makeConstant(0), coroAlign); coroSize = - rewriter.create(op->getLoc(), coroSize, negCoroAlign); + LLVM::AndOp::create(rewriter, op->getLoc(), coroSize, negCoroAlign); // Allocate memory for the coroutine frame. auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( rewriter, op->getParentOfType(), rewriter.getI64Type()); if (failed(allocFuncOp)) return failure(); - auto coroAlloc = rewriter.create( - loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); + auto coroAlloc = LLVM::CallOp::create(rewriter, loc, allocFuncOp.value(), + ValueRange{coroAlign, coroSize}); // Begin a coroutine: @llvm.coro.begin. auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); @@ -427,7 +428,7 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern { // Get a pointer to the coroutine frame memory: @llvm.coro.free. auto coroMem = - rewriter.create(loc, ptrType, adaptor.getOperands()); + LLVM::CoroFreeOp::create(rewriter, loc, ptrType, adaptor.getOperands()); // Free the memory. auto freeFuncOp = @@ -455,15 +456,15 @@ class CoroEndOpConversion : public OpConversionPattern { matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We are not in the block that is part of the unwind sequence. - auto constFalse = rewriter.create( - op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); - auto noneToken = rewriter.create(op->getLoc()); + auto constFalse = + LLVM::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI1Type(), + rewriter.getBoolAttr(false)); + auto noneToken = LLVM::NoneTokenOp::create(rewriter, op->getLoc()); // Mark the end of a coroutine: @llvm.coro.end. auto coroHdl = adaptor.getHandle(); - rewriter.create( - op->getLoc(), rewriter.getI1Type(), - ValueRange({coroHdl, constFalse, noneToken})); + LLVM::CoroEndOp::create(rewriter, op->getLoc(), rewriter.getI1Type(), + ValueRange({coroHdl, constFalse, noneToken})); rewriter.eraseOp(op); return success(); @@ -534,13 +535,13 @@ class CoroSuspendOpConversion : public OpConversionPattern { auto loc = op->getLoc(); // This is not a final suspension point. - auto constFalse = rewriter.create( - loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); + auto constFalse = LLVM::ConstantOp::create( + rewriter, loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Suspend a coroutine: @llvm.coro.suspend auto coroState = adaptor.getState(); - auto coroSuspend = rewriter.create( - loc, i8, ValueRange({coroState, constFalse})); + auto coroSuspend = LLVM::CoroSuspendOp::create( + rewriter, loc, i8, ValueRange({coroState, constFalse})); // Cast return code to i32. @@ -551,7 +552,7 @@ class CoroSuspendOpConversion : public OpConversionPattern { llvm::SmallVector caseDest = {op.getResumeDest(), op.getCleanupDest()}; rewriter.replaceOpWithNewOp( - op, rewriter.create(loc, i32, coroSuspend.getResult()), + op, LLVM::SExtOp::create(rewriter, loc, i32, coroSuspend.getResult()), /*defaultDestination=*/op.getSuspendDest(), /*defaultOperands=*/ValueRange(), /*caseValues=*/caseValues, @@ -602,11 +603,11 @@ class RuntimeCreateOpLowering : public ConvertOpToLLVMPattern { // %Size = getelementptr %T* null, int 1 // %SizeI = ptrtoint %T* %Size to i64 - auto nullPtr = rewriter.create(loc, storagePtrType); + auto nullPtr = LLVM::ZeroOp::create(rewriter, loc, storagePtrType); auto gep = - rewriter.create(loc, storagePtrType, storedType, - nullPtr, ArrayRef{1}); - return rewriter.create(loc, i64, gep); + LLVM::GEPOp::create(rewriter, loc, storagePtrType, storedType, + nullPtr, ArrayRef{1}); + return LLVM::PtrToIntOp::create(rewriter, loc, i64, gep); }; rewriter.replaceOpWithNewOp(op, kCreateValue, resultType, @@ -739,8 +740,8 @@ class RuntimeAwaitOpLowering : public OpConversionPattern { .Case([](Type) { return kAwaitValue; }) .Case([](Type) { return kAwaitGroup; }); - rewriter.create(op->getLoc(), apiFuncName, TypeRange(), - adaptor.getOperands()); + func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(), + adaptor.getOperands()); rewriter.eraseOp(op); return success(); @@ -772,13 +773,12 @@ class RuntimeAwaitAndResumeOpLowering // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); - auto resumePtr = rewriter.create( - op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), - kResume); + auto resumePtr = LLVM::AddressOfOp::create( + rewriter, op->getLoc(), + AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); - rewriter.create( - op->getLoc(), apiFuncName, TypeRange(), - ValueRange({operand, handle, resumePtr.getRes()})); + func::CallOp::create(rewriter, op->getLoc(), apiFuncName, TypeRange(), + ValueRange({operand, handle, resumePtr.getRes()})); rewriter.eraseOp(op); return success(); @@ -801,9 +801,9 @@ class RuntimeResumeOpLowering ConversionPatternRewriter &rewriter) const override { // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); - auto resumePtr = rewriter.create( - op->getLoc(), AsyncAPI::opaquePointerType(rewriter.getContext()), - kResume); + auto resumePtr = LLVM::AddressOfOp::create( + rewriter, op->getLoc(), + AsyncAPI::opaquePointerType(rewriter.getContext()), kResume); // Call async runtime API to execute a coroutine in the managed thread. auto coroHdl = adaptor.getHandle(); @@ -832,8 +832,8 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern { // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); - auto storagePtr = rewriter.create( - loc, kGetValueStorage, TypeRange(ptrType), storage); + auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage, + TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. auto valueType = op.getValue().getType(); @@ -845,7 +845,7 @@ class RuntimeStoreOpLowering : public ConvertOpToLLVMPattern { Value castedStoragePtr = storagePtr.getResult(0); // Store the yielded value into the async value storage. auto value = adaptor.getValue(); - rewriter.create(loc, value, castedStoragePtr); + LLVM::StoreOp::create(rewriter, loc, value, castedStoragePtr); // Erase the original runtime store operation. rewriter.eraseOp(op); @@ -872,8 +872,8 @@ class RuntimeLoadOpLowering : public ConvertOpToLLVMPattern { // Get a pointer to the async value storage from the runtime. auto ptrType = AsyncAPI::opaquePointerType(rewriter.getContext()); auto storage = adaptor.getStorage(); - auto storagePtr = rewriter.create( - loc, kGetValueStorage, TypeRange(ptrType), storage); + auto storagePtr = func::CallOp::create(rewriter, loc, kGetValueStorage, + TypeRange(ptrType), storage); // Cast from i8* to the LLVM pointer type. auto valueType = op.getResult().getType(); @@ -960,9 +960,9 @@ class RefCountingOpLowering : public OpConversionPattern { LogicalResult matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto count = rewriter.create( - op->getLoc(), rewriter.getI64Type(), - rewriter.getI64IntegerAttr(op.getCount())); + auto count = + arith::ConstantOp::create(rewriter, op->getLoc(), rewriter.getI64Type(), + rewriter.getI64IntegerAttr(op.getCount())); auto operand = adaptor.getOperand(); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index b9991f36cdaaf..30a7170cf5c6a 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -47,26 +47,26 @@ struct CloneOpConversion : public OpConversionPattern { if (auto unrankedType = dyn_cast(type)) { // Constants - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); // Dynamically evaluate the size and shape of the unranked memref - Value rank = rewriter.create(loc, op.getInput()); + Value rank = memref::RankOp::create(rewriter, loc, op.getInput()); MemRefType allocType = MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()); - Value shape = rewriter.create(loc, allocType, rank); + Value shape = memref::AllocaOp::create(rewriter, loc, allocType, rank); // Create a loop to query dimension sizes, store them as a shape, and // compute the total size of the memref auto loopBody = [&](OpBuilder &builder, Location loc, Value i, ValueRange args) { auto acc = args.front(); - auto dim = rewriter.create(loc, op.getInput(), i); + auto dim = memref::DimOp::create(rewriter, loc, op.getInput(), i); - rewriter.create(loc, dim, shape, i); - acc = rewriter.create(loc, acc, dim); + memref::StoreOp::create(rewriter, loc, dim, shape, i); + acc = arith::MulIOp::create(rewriter, loc, acc, dim); - rewriter.create(loc, acc); + scf::YieldOp::create(rewriter, loc, acc); }; auto size = rewriter .create(loc, zero, rank, one, ValueRange(one), @@ -78,9 +78,9 @@ struct CloneOpConversion : public OpConversionPattern { // Allocate new memref with 1D dynamic shape, then reshape into the // shape of the original unranked memref - alloc = rewriter.create(loc, memrefType, size); + alloc = memref::AllocOp::create(rewriter, loc, memrefType, size); alloc = - rewriter.create(loc, unrankedType, alloc, shape); + memref::ReshapeOp::create(rewriter, loc, unrankedType, alloc, shape); } else { MemRefType memrefType = cast(type); MemRefLayoutAttrInterface layout; @@ -103,14 +103,15 @@ struct CloneOpConversion : public OpConversionPattern { } // Allocate a memref with identity layout. - alloc = rewriter.create(loc, allocType, dynamicOperands); + alloc = + memref::AllocOp::create(rewriter, loc, allocType, dynamicOperands); // Cast the allocation to the specified type if needed. if (memrefType != allocType) alloc = - rewriter.create(op->getLoc(), memrefType, alloc); + memref::CastOp::create(rewriter, op->getLoc(), memrefType, alloc); } - rewriter.create(loc, op.getInput(), alloc); + memref::CopyOp::create(rewriter, loc, op.getInput(), alloc); rewriter.replaceOp(op, alloc); return success(); }