Skip to content

Commit 8e7b02f

Browse files
Avoid copies in getChecked (#147721)
Following-up on #68067 ; adding std::move to getChecked method as well.
1 parent 73245b0 commit 8e7b02f

File tree

6 files changed

+44
-7
lines changed

6 files changed

+44
-7
lines changed

mlir/include/mlir/IR/StorageUniquerSupport.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
200200
// If the construction invariants fail then we return a null attribute.
201201
if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
202202
return ConcreteT();
203-
return UniquerT::template get<ConcreteT>(ctx, args...);
203+
return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
204204
}
205205

206206
/// Get an instance of the concrete type from a void pointer.

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> {
347347
let mnemonic = "copy_count";
348348
let parameters = (ins TestParamCopyCount:$copy_count);
349349
let assemblyFormat = "`<` $copy_count `>`";
350+
let genVerifyDecl = 1;
350351
}
351352

352353
def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,16 @@ static void printTrueFalse(AsmPrinter &p, std::optional<int> result) {
213213
p << (*result ? "true" : "false");
214214
}
215215

216+
//===----------------------------------------------------------------------===//
217+
// TestCopyCountAttr Implementation
218+
//===----------------------------------------------------------------------===//
219+
220+
LogicalResult TestCopyCountAttr::verify(
221+
llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/,
222+
CopyCount /*copy_count*/) {
223+
return success();
224+
}
225+
216226
//===----------------------------------------------------------------------===//
217227
// CopyCountAttr Implementation
218228
//===----------------------------------------------------------------------===//

mlir/test/mlir-tblgen/attrdefs.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> {
115115
// DEF: return new (allocator.allocate<CompoundAAttrStorage>())
116116
// DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
117117

118+
// DEF: CompoundAAttr CompoundAAttr::getChecked(
119+
// DEF-SAME: int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner
120+
// DEF-SAME: )
121+
// DEF-NEXT: return Base::getChecked(emitError, context, std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner));
122+
118123
// DEF: ::mlir::Type CompoundAAttr::getInner() const {
119124
// DEF-NEXT: return getImpl()->inner;
120125
}

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,7 @@ void DefGen::emitCheckedBuilder() {
495495
MethodBody &body = m->body().indent();
496496
auto scope = body.scope("return Base::getChecked(emitError, context", ");");
497497
for (const auto &param : params)
498-
body << ", " << param.getName();
498+
body << ", std::move(" << param.getName() << ")";
499499
}
500500

501501
static SmallVector<MethodParameter>

mlir/unittests/IR/AttributeTest.cpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -477,8 +477,9 @@ TEST(SubElementTest, Nested) {
477477
{strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr}));
478478
}
479479

480-
// Test how many times we call copy-ctor when building an attribute.
481-
TEST(CopyCountAttr, CopyCount) {
480+
// Test how many times we call copy-ctor when building an attribute with the
481+
// 'get' method.
482+
TEST(CopyCountAttr, CopyCountGet) {
482483
MLIRContext context;
483484
context.loadDialect<test::TestDialect>();
484485

@@ -489,15 +490,35 @@ TEST(CopyCountAttr, CopyCount) {
489490
test::CopyCount::counter = 0;
490491
test::TestCopyCountAttr::get(&context, std::move(copyCount));
491492
#ifndef NDEBUG
492-
// One verification enabled only in assert-mode requires a copy.
493-
EXPECT_EQ(counter1, 1);
494-
EXPECT_EQ(test::CopyCount::counter, 1);
493+
// One verification enabled only in assert-mode requires two copies: one for
494+
// calling 'verifyInvariants' and one for calling 'verify' inside
495+
// 'verifyInvariants'.
496+
EXPECT_EQ(counter1, 2);
497+
EXPECT_EQ(test::CopyCount::counter, 2);
495498
#else
496499
EXPECT_EQ(counter1, 0);
497500
EXPECT_EQ(test::CopyCount::counter, 0);
498501
#endif
499502
}
500503

504+
// Test how many times we call copy-ctor when building an attribute with the
505+
// 'getChecked' method.
506+
TEST(CopyCountAttr, CopyCountGetChecked) {
507+
MLIRContext context;
508+
context.loadDialect<test::TestDialect>();
509+
test::CopyCount::counter = 0;
510+
test::CopyCount copyCount("hello");
511+
auto loc = UnknownLoc::get(&context);
512+
test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount));
513+
int counter1 = test::CopyCount::counter;
514+
test::CopyCount::counter = 0;
515+
test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount));
516+
// The verifiers require two copies: one for calling 'verifyInvariants' and
517+
// one for calling 'verify' inside 'verifyInvariants'.
518+
EXPECT_EQ(counter1, 2);
519+
EXPECT_EQ(test::CopyCount::counter, 2);
520+
}
521+
501522
// Test stripped printing using test dialect attribute.
502523
TEST(CopyCountAttr, PrintStripped) {
503524
MLIRContext context;

0 commit comments

Comments
 (0)