Skip to content

Commit 46b7fd3

Browse files
committed
code review
1 parent 8c418a2 commit 46b7fd3

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,16 @@ static Value getLaneId(RewriterBase &rewriter, Location loc) {
101101
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
102102
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
103103
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
104-
NamedAttribute noundef = rewriter.getNamedAttr(
105-
LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
106-
NamedAttribute lowRange = rewriter.getNamedAttr(
104+
NamedAttribute noundef = {LLVM::LLVMDialect::getNoUndefAttrName(),
105+
rewriter.getUnitAttr()};
106+
NamedAttribute lowRange = {LLVM::LLVMDialect::getRangeAttrName(),
107+
LLVM::ConstantRangeAttr::get(rewriter.getContext(),
108+
APInt::getZero(32),
109+
APInt(32, 32))};
110+
NamedAttribute highRange = {
107111
LLVM::LLVMDialect::getRangeAttrName(),
108112
LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
109-
APInt(32, 32)));
110-
NamedAttribute highRange = rewriter.getNamedAttr(
111-
LLVM::LLVMDialect::getRangeAttrName(),
112-
LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
113-
APInt(32, 64)));
113+
APInt(32, 64))};
114114
Value mbcntLo = ROCDL::MbcntLoOp::create(
115115
rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
116116
/*res_attrs=*/
@@ -133,9 +133,7 @@ struct PromoteShuffleToDPPPattern : public OpRewritePattern<gpu::ShuffleOp> {
133133
return rewriter.notifyMatchFailure(op,
134134
"width must be a constant integer");
135135
int64_t widthValue = *width;
136-
if (widthValue != 4 && widthValue != 8 && widthValue != 12 &&
137-
widthValue != 16 && widthValue != 32 && widthValue != 48 &&
138-
widthValue != 64)
136+
if (!llvm::is_contained({4, 8, 12, 16, 32, 48, 64}, widthValue))
139137
return rewriter.notifyMatchFailure(
140138
op, "width must be 4, 8, 12, 16, 32, 48 or 64");
141139

0 commit comments

Comments
 (0)