-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[mlir][NFC] update Conversion
create APIs (4/n)
#149879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][NFC] update Conversion
create APIs (4/n)
#149879
Conversation
8c95839
to
75ef3da
Compare
See llvm#147168 for more info.
75ef3da
to
24edc8d
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-backend-amdgpu Author: Maksim Levental (makslevental) ChangesSee #147168 for more info. Patch is 134.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149879.diff 12 Files Affected:
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index ef35ee208f002..fe3dc91328879 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -50,20 +50,20 @@ static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
if (i32 == valTy)
return val;
return valTy.getWidth() > 32
- ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
- : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+ ? Value(LLVM::TruncOp::create(rewriter, loc, i32, val))
+ : Value(LLVM::ZExtOp::create(rewriter, loc, i32, val));
}
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
Type i32 = rewriter.getI32Type();
- return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
+ return LLVM::ConstantOp::create(rewriter, loc, i32, value);
}
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
bool value) {
Type llvmI1 = rewriter.getI1Type();
- return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
+ return LLVM::ConstantOp::create(rewriter, loc, llvmI1, value);
}
/// Returns the linear index used to access an element in the memref.
@@ -78,11 +78,11 @@ static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
ShapedType::isDynamic(stride)
? convertUnsignedToI32(rewriter, loc,
memRefDescriptor.stride(rewriter, loc, i))
- : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
- increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
+ : LLVM::ConstantOp::create(rewriter, loc, i32, stride);
+ increment = LLVM::MulOp::create(rewriter, loc, increment, strideValue);
}
- index =
- index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+ index = index ? LLVM::AddOp::create(rewriter, loc, index, increment)
+ : increment;
}
return index ? index : createI32Constant(rewriter, loc, 0);
}
@@ -110,14 +110,14 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
Value size = memrefDescriptor.size(rewriter, loc, i);
Value stride = memrefDescriptor.stride(rewriter, loc, i);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
+ Value maxThisDim = LLVM::MulOp::create(rewriter, loc, size, stride);
maxIndex = maxIndex
- ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ ? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
: maxThisDim;
}
Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
- return rewriter.create<LLVM::MulOp>(loc, maxIndexI32, byteWidthConst);
+ return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
}
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
@@ -132,14 +132,14 @@ static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,
Value stride;
if (chipset.majorVersion == 9 && chipset >= kGfx942 && cacheSwizzleStride) {
Value cacheStrideZext =
- rewriter.create<LLVM::ZExtOp>(loc, i16, cacheSwizzleStride);
- Value swizzleBit = rewriter.create<LLVM::ConstantOp>(
- loc, i16, rewriter.getI16IntegerAttr(1 << 14));
- stride = rewriter.create<LLVM::OrOp>(loc, cacheStrideZext, swizzleBit,
- /*isDisjoint=*/true);
+ LLVM::ZExtOp::create(rewriter, loc, i16, cacheSwizzleStride);
+ Value swizzleBit = LLVM::ConstantOp::create(
+ rewriter, loc, i16, rewriter.getI16IntegerAttr(1 << 14));
+ stride = LLVM::OrOp::create(rewriter, loc, cacheStrideZext, swizzleBit,
+ /*isDisjoint=*/true);
} else {
- stride = rewriter.create<LLVM::ConstantOp>(loc, i16,
- rewriter.getI16IntegerAttr(0));
+ stride = LLVM::ConstantOp::create(rewriter, loc, i16,
+ rewriter.getI16IntegerAttr(0));
}
// Get the number of elements.
// Flag word:
@@ -209,20 +209,21 @@ struct FatRawBufferCastLowering
: descriptor.alignedPtr(rewriter, loc);
Value offset = adaptor.getResetOffset()
- ? rewriter.create<LLVM::ConstantOp>(
- loc, getIndexType(), rewriter.getIndexAttr(0))
+ ? LLVM::ConstantOp::create(rewriter, loc, getIndexType(),
+ rewriter.getIndexAttr(0))
: descriptor.offset(rewriter, loc);
bool hasSizes = memrefType.getRank() > 0;
// No need to unpack() and pack() all the individual sizes and strides,
// so we'll just extract the arrays.
- Value sizes = hasSizes ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kSizePosInMemRefDescriptor)
- : Value{};
- Value strides = hasSizes
- ? rewriter.create<LLVM::ExtractValueOp>(
- loc, descriptor, kStridePosInMemRefDescriptor)
- : Value{};
+ Value sizes = hasSizes
+ ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kSizePosInMemRefDescriptor)
+ : Value{};
+ Value strides =
+ hasSizes ? LLVM::ExtractValueOp::create(rewriter, loc, descriptor,
+ kStridePosInMemRefDescriptor)
+ : Value{};
Value fatPtr = makeBufferRsrc(
rewriter, loc, basePointer, numRecords, adaptor.getBoundsCheck(),
@@ -231,17 +232,17 @@ struct FatRawBufferCastLowering
Value result = MemRefDescriptor::poison(
rewriter, loc,
getTypeConverter()->convertType(op.getResult().getType()));
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAllocatedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, offset,
- kOffsetPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+ kAllocatedPtrPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr,
+ kAlignedPtrPosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, offset,
+ kOffsetPosInMemRefDescriptor);
if (hasSizes) {
- result = rewriter.create<LLVM::InsertValueOp>(loc, result, sizes,
- kSizePosInMemRefDescriptor);
- result = rewriter.create<LLVM::InsertValueOp>(
- loc, result, strides, kStridePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, sizes,
+ kSizePosInMemRefDescriptor);
+ result = LLVM::InsertValueOp::create(rewriter, loc, result, strides,
+ kStridePosInMemRefDescriptor);
}
rewriter.replaceOp(op, result);
return success();
@@ -342,8 +343,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
SmallVector<Value, 6> args;
if (storeData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForStore =
- rewriter.create<LLVM::BitcastOp>(loc, llvmBufferValType, storeData);
+ Value castForStore = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, storeData);
args.push_back(castForStore);
} else {
args.push_back(storeData);
@@ -352,8 +353,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (atomicCmpData) {
if (llvmBufferValType != llvmWantedDataType) {
- Value castForCmp = rewriter.create<LLVM::BitcastOp>(
- loc, llvmBufferValType, atomicCmpData);
+ Value castForCmp = LLVM::BitcastOp::create(
+ rewriter, loc, llvmBufferValType, atomicCmpData);
args.push_back(castForCmp);
} else {
args.push_back(atomicCmpData);
@@ -382,18 +383,18 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
indexOffset && *indexOffset > 0) {
Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
- voffset =
- voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
- : extraOffsetConst;
+ voffset = voffset ? LLVM::AddOp::create(rewriter, loc, voffset,
+ extraOffsetConst)
+ : extraOffsetConst;
}
- voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
+ voffset = LLVM::MulOp::create(rewriter, loc, voffset, byteWidthConst);
args.push_back(voffset);
// SGPR offset.
Value sgprOffset = adaptor.getSgprOffset();
if (!sgprOffset)
sgprOffset = createI32Constant(rewriter, loc, 0);
- sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
+ sgprOffset = LLVM::MulOp::create(rewriter, loc, sgprOffset, byteWidthConst);
args.push_back(sgprOffset);
// bit 0: GLC = 0 (atomics drop value, less coherency)
@@ -403,13 +404,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
llvm::SmallVector<Type, 1> resultTypes(gpuOp->getNumResults(),
llvmBufferValType);
- Operation *lowered = rewriter.create<Intrinsic>(loc, resultTypes, args,
- ArrayRef<NamedAttribute>());
+ Operation *lowered = Intrinsic::create(rewriter, loc, resultTypes, args,
+ ArrayRef<NamedAttribute>());
if (lowered->getNumResults() == 1) {
Value replacement = lowered->getResult(0);
if (llvmBufferValType != llvmWantedDataType) {
- replacement = rewriter.create<LLVM::BitcastOp>(loc, llvmWantedDataType,
- replacement);
+ replacement = LLVM::BitcastOp::create(rewriter, loc, llvmWantedDataType,
+ replacement);
}
rewriter.replaceOp(gpuOp, replacement);
} else {
@@ -465,12 +466,12 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
<< chipset.majorVersion;
Location loc = op->getLoc();
- rewriter.create<ROCDL::SWaitcntOp>(loc, ldsOnlyBits);
+ ROCDL::SWaitcntOp::create(rewriter, loc, ldsOnlyBits);
rewriter.replaceOpWithNewOp<ROCDL::SBarrierOp>(op);
} else {
Location loc = op->getLoc();
- rewriter.create<ROCDL::WaitDscntOp>(loc, 0);
- rewriter.create<ROCDL::BarrierSignalOp>(loc, -1);
+ ROCDL::WaitDscntOp::create(rewriter, loc, 0);
+ ROCDL::BarrierSignalOp::create(rewriter, loc, -1);
rewriter.replaceOpWithNewOp<ROCDL::BarrierWaitOp>(op, -1);
}
@@ -516,19 +517,21 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
- return rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), input);
if (vectorType.getElementType().isInteger(8) &&
vectorType.getNumElements() <= 8)
- return rewriter.create<LLVM::BitcastOp>(
- loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc,
+ rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
if (isa<IntegerType>(vectorType.getElementType()) &&
vectorType.getElementTypeBitWidth() <= 8) {
int64_t numWords = llvm::divideCeil(
vectorType.getNumElements() * vectorType.getElementTypeBitWidth(),
32);
- return rewriter.create<LLVM::BitcastOp>(
- loc, VectorType::get(numWords, rewriter.getI32Type()), input);
+ return LLVM::BitcastOp::create(
+ rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()),
+ input);
}
}
return input;
@@ -549,8 +552,8 @@ static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
Type inputType = input.getType();
Type outputType = rewriter.getI32Type();
if (auto intType = dyn_cast<IntegerType>(inputType))
- return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
- return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
+ return LLVM::ZExtOp::create(rewriter, loc, outputType, input);
+ return LLVM::BitcastOp::create(rewriter, loc, outputType, input);
}
/// Push an input operand. If it is a float type, nothing to do. If it is
@@ -576,8 +579,8 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- llvmInput = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
+ llvmInput = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), llvmInput);
if (elemType.getIntOrFloatBitWidth() > 8) {
operands.push_back(llvmInput);
return;
@@ -613,7 +616,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
// (256 / 64) * 4 = 16 bits of input (on gfx12+) but take i32 arguments.
// Add in the zeros here.
if (numBits < 32)
- castInput = rewriter.create<LLVM::ZExtOp>(loc, i32, castInput);
+ castInput = LLVM::ZExtOp::create(rewriter, loc, i32, castInput);
operands.push_back(castInput);
}
@@ -633,8 +636,8 @@ static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
auto vectorType = dyn_cast<VectorType>(inputType);
Type elemType = vectorType.getElementType();
if (elemType.isBF16())
- output = rewriter.create<LLVM::BitcastOp>(
- loc, vectorType.clone(rewriter.getI16Type()), output);
+ output = LLVM::BitcastOp::create(
+ rewriter, loc, vectorType.clone(rewriter.getI16Type()), output);
operands.push_back(output);
if (elemType.isF16() || elemType.isBF16() || elemType.isInteger(16)) {
operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
@@ -992,7 +995,7 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
};
Value lowered = rewriter.create(loweredOp)->getResult(0);
if (outType != intrinsicOutType)
- lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+ lowered = LLVM::BitcastOp::create(rewriter, loc, outType, lowered);
rewriter.replaceOp(op, lowered);
return success();
}
@@ -1092,8 +1095,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Operation *maybeCastBack = lowered;
if (rawOutType != outType)
- maybeCastBack =
- rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
+ maybeCastBack = LLVM::BitcastOp::create(rewriter, loc, outType,
+ lowered->getResult(0));
rewriter.replaceOp(op, maybeCastBack->getResults());
return success();
@@ -1143,22 +1146,22 @@ struct TransposeLoadOpLowering
switch (elementTypeSize) {
case 4: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr4_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr4_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 6: {
assert(numElements == 16);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr6_b96>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr6_b96::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
case 8: {
assert(numElements == 8);
- auto rocdlOp =
- rewriter.create<ROCDL::ds_read_tr8_b64>(loc, rocdlResultType, srcPtr);
+ auto rocdlOp = ROCDL::ds_read_tr8_b64::create(rewriter, loc,
+ rocdlResultType, srcPtr);
rewriter.replaceOpWithNewOp<LLVM::BitcastOp>(op, llvmResultType, rocdlOp);
break;
}
@@ -1316,21 +1319,21 @@ LogicalResult ExtPackedFp8OpLowering::matchAndRewrite(
Type sourceElemType = getElementTypeOrSelf(op.getSource());
// Extend to a v4i8
if (!sourceVecType || sourceVecType.getNumElements() < 4) {
- Value longVec = rewriter.create<LLVM::UndefOp>(loc, v4i8);
+ Value longVec = LLVM::UndefOp::create(rewriter, loc, v4i8);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (resultVecType) {
if (typeIsExpectedBf8ForChipset(chipset, sourceElemType)) {
rewriter.replaceOpWithNewOp<ROCDL::CvtPkF32Bf8Op>(op, f32, i32Source,
@@ -1382,21 +1385,21 @@ LogicalResult ScaledExtPackedOpLowering::matchAndRewrite(
// Extend to a packedVectorType
if (sourceVecType.getNumElements() < packedVecType.getNumElements()) {
- Value longVec = rewriter.create<LLVM::ZeroOp>(loc, packedVecType);
+ Value longVec = LLVM::ZeroOp::create(rewriter, loc, packedVecType);
if (!sourceVecType) {
- longVec = rewriter.create<LLVM::InsertElementOp>(
- loc, longVec, source, createI32Constant(rewriter, loc, 0));
+ longVec = LLVM::InsertElementOp::create(
+ rewriter, loc, longVec, source, createI32Constant(rewriter, loc, 0));
} else {
for (int32_t i = 0, e = sourceVecType.getNumElements(); i < e; ++i) {
Value idx = createI32Constant(rewriter, loc, i);
- Value elem = rewriter.create<LLVM::ExtractElementOp>(loc, source, idx);
+ Value elem = LLVM::ExtractElementOp::create(rewriter, loc, source, idx);
longVec =
- rewriter.create<LLVM::InsertElementOp>(loc, longVec, elem, idx);
+ LLVM::InsertElementOp::create(rewriter, loc, longVec, elem, idx);
}
}
source = longVec;
}
- Value i32Source = rewriter.create<LLVM::BitcastOp>(loc, i32, source);
+ Value i32Source = LLVM::BitcastOp::create(rewriter, loc, i32, source);
if (isa<Float8E5M2Type>(sourceElemType) && destElemType.isF32())
rewriter.replaceOpWithNewOp<ROCDL::CvtScaleF32PkF32Bf8Op>(
@@ -1454,54 +1457,57 @@ LogicalResult PackedScaledTruncOpLowering::matchAndRewrite(
Value scale = adaptor.getScale();
Value existing ...
[truncated]
|
Conversion
create APIs (4/n) (#149687)Conversion
create APIs (4/n)
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Potentially useful for people rebasing with merge conflicts after this lands: I can reproduce this diff with
find . -name '*.cpp' -exec sed -E -i 's/(\w+)[[:space:]]*\.[[:space:]]*create<([^>]*)>\(/\2::create(\1, /g' {} +
on linux from within the appropriate directory.
@makslevental if your goal is to eliminate all create<.>
s you'll need to rerun your regex as there are already a sprinkling of new ones on main which aren't caught in this PR. But I suspect that's not your goal?
it is but I'll just come back around and do this again to catch the stragglers after this batch lands. |
See #147168 for more info.