Skip to content

[mlir][ods] Enable basic string interpolation in constraint summary. #153603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion mlir/include/mlir/TableGen/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class NamespaceEmitter {
SmallVector<StringRef, 2> namespaces;
};

enum class ErrorStreamType {
// Inside a string that's streamed into an InflightDiagnostic.
InString,
// Inside a string inside an OpError.
InsideOpError,
};

/// This class deduplicates shared operation verification code by emitting
/// static functions alongside the op definitions. These methods are local to
/// the definition file, and are invoked within the operation verify methods.
Expand Down Expand Up @@ -218,7 +225,8 @@ class StaticVerifierFunctionEmitter {

/// A generic function to emit constraints
void emitConstraints(const ConstraintMap &constraints, StringRef selfName,
const char *codeTemplate);
const char *codeTemplate,
ErrorStreamType errorStreamType);

/// Assign a unique name to a unique constraint.
std::string getUniqueName(StringRef kind, unsigned index);
Expand Down Expand Up @@ -269,6 +277,16 @@ std::string stringify(T &&t) {
apply(std::forward<T>(t));
}

/// Helper to generate a C++ streaming error messages from a given message.
/// Message can contain '{{...}}' placeholders that are substituted with
/// C-expressions via tgfmt. It would effectively convert:
/// "Failed to verify {{foo}}"
/// into:
/// "Failed to verify " << tgfmt(foo, &ctx)
std::string buildErrorStreamingString(
StringRef message, const FmtContext &ctx,
ErrorStreamType errorStreamType = ErrorStreamType::InString);

} // namespace tblgen
} // namespace mlir

Expand Down
91 changes: 80 additions & 11 deletions mlir/lib/TableGen/CodeGenHelpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,25 @@
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Argument.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "mlir/TableGen/Pattern.h"
#include "mlir/TableGen/Property.h"
#include "mlir/TableGen/Region.h"
#include "mlir/TableGen/Successor.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <cassert>
#include <optional>
#include <string>

using namespace llvm;
using namespace mlir;
Expand Down Expand Up @@ -113,6 +127,56 @@ StringRef StaticVerifierFunctionEmitter::getRegionConstraintFn(
// Constraint Emission
//===----------------------------------------------------------------------===//

/// Helper to generate a C++ string expression from a given message.
/// Message can contain '{{...}}' placeholders that are substituted with
/// C-expressions via tgfmt.
std::string mlir::tblgen::buildErrorStreamingString(
StringRef message, const FmtContext &ctx, ErrorStreamType errorStreamType) {
std::string result;
raw_string_ostream os(result);

std::string msgStr = escapeString(message);
StringRef msg = msgStr;

// Split the message by '{{' and '}}' and build a streaming expression.
auto split = msg.split("{{");
if (split.second.empty()) {
os << split.first;
return msgStr;
}

os << split.first;
if (errorStreamType == ErrorStreamType::InsideOpError)
os << "\")";
else
os << '"';

msg = split.second;
while (!msg.empty()) {
split = msg.split("}}");
StringRef var = split.first;
StringRef rest = split.second;

os << " << " << tgfmt(var, &ctx);

if (rest.empty())
break;

split = rest.split("{{");
if (split.second.empty() &&
errorStreamType == ErrorStreamType::InsideOpError) {
// To enable having part of string post, this adds a parenthesis before
// the last string segment to match the exiting one.
os << " << (\"" << split.first;
} else {
os << " << \"" << split.first;
}
msg = split.second;
}

return os.str();
}

/// Code templates for emitting type, attribute, successor, and region
/// constraints. Each of these templates require the following arguments:
///
Expand Down Expand Up @@ -225,22 +289,24 @@ static ::llvm::LogicalResult {0}(

void StaticVerifierFunctionEmitter::emitConstraints(
const ConstraintMap &constraints, StringRef selfName,
const char *const codeTemplate) {
const char *const codeTemplate, ErrorStreamType errorStreamType) {
FmtContext ctx;
ctx.addSubst("_op", "*op").withSelf(selfName);

for (auto &it : constraints) {
os << formatv(codeTemplate, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
escapeString(it.first.getSummary()));
buildErrorStreamingString(it.first.getSummary(), ctx));
}
}

void StaticVerifierFunctionEmitter::emitTypeConstraints() {
emitConstraints(typeConstraints, "type", typeConstraintCode);
emitConstraints(typeConstraints, "type", typeConstraintCode,
ErrorStreamType::InString);
}

void StaticVerifierFunctionEmitter::emitAttrConstraints() {
emitConstraints(attrConstraints, "attr", attrConstraintCode);
emitConstraints(attrConstraints, "attr", attrConstraintCode,
ErrorStreamType::InString);
}

/// Unlike with the other helpers, this one has to substitute in the interface
Expand All @@ -252,17 +318,19 @@ void StaticVerifierFunctionEmitter::emitPropConstraints() {
auto propConstraint = cast<PropConstraint>(it.first);
os << formatv(propConstraintCode, it.second,
tgfmt(propConstraint.getConditionTemplate(), &ctx),
escapeString(it.first.getSummary()),
buildErrorStreamingString(it.first.getSummary(), ctx),
propConstraint.getInterfaceType());
}
}

void StaticVerifierFunctionEmitter::emitSuccessorConstraints() {
emitConstraints(successorConstraints, "successor", successorConstraintCode);
emitConstraints(successorConstraints, "successor", successorConstraintCode,
ErrorStreamType::InString);
}

void StaticVerifierFunctionEmitter::emitRegionConstraints() {
emitConstraints(regionConstraints, "region", regionConstraintCode);
emitConstraints(regionConstraints, "region", regionConstraintCode,
ErrorStreamType::InString);
}

void StaticVerifierFunctionEmitter::emitPatternConstraints() {
Expand All @@ -271,13 +339,14 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
for (auto &it : typeConstraints) {
os << formatv(patternConstraintCode, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
escapeString(it.first.getSummary()), "::mlir::Type type");
buildErrorStreamingString(it.first.getSummary(), ctx),
"::mlir::Type type");
}
ctx.withSelf("attr");
for (auto &it : attrConstraints) {
os << formatv(patternConstraintCode, it.second,
tgfmt(it.first.getConditionTemplate(), &ctx),
escapeString(it.first.getSummary()),
buildErrorStreamingString(it.first.getSummary(), ctx),
"::mlir::Attribute attr");
}
ctx.withSelf("prop");
Expand All @@ -292,7 +361,7 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
}
os << formatv(patternConstraintCode, it.second,
tgfmt(propConstraint.getConditionTemplate(), &ctx),
escapeString(propConstraint.getSummary()),
buildErrorStreamingString(propConstraint.getSummary(), ctx),
Twine(interfaceType) + " prop");
}
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/mlir-tblgen/constraint-unique.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def AType : Type<ATypePred, "a type">;
def OtherType : Type<ATypePred, "another type">;

def AnAttrPred : CPred<"attrPred($_self, $_op)">;
def AnAttr : Attr<AnAttrPred, "an attribute">;
def AnAttr : Attr<AnAttrPred, "an attribute (got {{reformat($_self)}})">;
def OtherAttr : Attr<AnAttrPred, "another attribute">;

def ASuccessorPred : CPred<"successorPred($_self, $_op)">;
Expand Down Expand Up @@ -71,10 +71,10 @@ def OpC : NS_Op<"op_c"> {
// CHECK: static ::llvm::LogicalResult [[$A_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
// CHECK: if (attr && !((attrPred(attr, *op))))
// CHECK-NEXT: return emitError() << "attribute '" << attrName
// CHECK-NEXT: << "' failed to satisfy constraint: an attribute";
// CHECK-NEXT: << "' failed to satisfy constraint: an attribute (got " << reformat(attr) << ")";

/// Test that duplicate attribute constraint was not generated.
// CHECK-NOT: << "' failed to satisfy constraint: an attribute";
// CHECK-NOT: << "' failed to satisfy constraint: an attribute

/// Test that a attribute constraint with a different description was generated.
// CHECK: static ::llvm::LogicalResult [[$O_ATTR_CONSTRAINT:__mlir_ods_local_attr_constraint.*]](
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/mlir-tblgen/op-attribute.td
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,19 @@ def AOp : NS_Op<"a_op", []> {

// DEF: ::llvm::LogicalResult AOpAdaptor::verify
// DEF-NEXT: auto tblgen_aAttr = getProperties().aAttr; (void)tblgen_aAttr;
// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op requires attribute 'aAttr'");
// DEF-NEXT: auto tblgen_bAttr = getProperties().bAttr; (void)tblgen_bAttr;
// DEF-NEXT: auto tblgen_cAttr = getProperties().cAttr; (void)tblgen_cAttr;
// DEF-NEXT: auto tblgen_dAttr = getProperties().dAttr; (void)tblgen_dAttr;

// DEF: if (tblgen_aAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'aAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_bAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'bAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_cAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'cAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_dAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'dAttr' failed to satisfy constraint: some attribute kind");
// DEF-NEXT: return emitError(loc, "'test.a_op' op attribute 'dAttr' failed to satisfy constraint: some attribute kind");

// Test getter methods
// ---
Expand Down Expand Up @@ -219,13 +219,13 @@ def AgetOp : Op<Test2_Dialect, "a_get_op", []> {

// DEF: ::llvm::LogicalResult AgetOpAdaptor::verify
// DEF: auto tblgen_aAttr = getProperties().aAttr; (void)tblgen_aAttr;
// DEF: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'");
// DEF: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op requires attribute 'aAttr'");
// DEF: auto tblgen_bAttr = getProperties().bAttr; (void)tblgen_bAttr;
// DEF: auto tblgen_cAttr = getProperties().cAttr; (void)tblgen_cAttr;
// DEF: if (tblgen_bAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op attribute 'bAttr' failed to satisfy constraint: some attribute kind");
// DEF: if (tblgen_cAttr && !((some-condition)))
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
// DEF-NEXT: return emitError(loc, "'test2.a_get_op' op attribute 'cAttr' failed to satisfy constraint: some attribute kind");

// Test getter methods
// ---
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/mlir-tblgen/op-properties-predicates.td
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def OpWithPredicates : NS_Op<"op_with_predicates"> {
// Note: comprehensive emission of verifiers is tested in verifyINvariantsImpl() below
// CHECK: int64_t tblgen_scalar = this->getScalar();
// CHECK: if (!((tblgen_scalar >= 0)))
// CHECK: return emitError(loc, "'test.op_with_predicates' op ""property 'scalar' failed to satisfy constraint: non-negative int64_t");
// CHECK: return emitError(loc, "'test.op_with_predicates' op property 'scalar' failed to satisfy constraint: non-negative int64_t");

// CHECK-LABEL: OpWithPredicates::verifyInvariantsImpl()
// Note: for test readability, we capture [[maybe_unused]] into the variable maybe_unused
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/mlir-tblgen/predicate.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,55 +55,55 @@ def OpF : NS_Op<"op_for_int_min_val", []> {

// CHECK-LABEL: OpFAdaptor::verify
// CHECK: (::llvm::cast<::mlir::IntegerAttr>(tblgen_attr).getInt() >= 10)
// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"
// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"

def OpFX : NS_Op<"op_for_int_max_val", []> {
let arguments = (ins ConfinedAttr<I32Attr, [IntMaxValue<10>]>:$attr);
}

// CHECK-LABEL: OpFXAdaptor::verify
// CHECK: (::llvm::cast<::mlir::IntegerAttr>(tblgen_attr).getInt() <= 10)
// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"
// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"

def OpG : NS_Op<"op_for_arr_min_count", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [ArrayMinCount<8>]>:$attr);
}

// CHECK-LABEL: OpGAdaptor::verify
// CHECK: (::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() >= 8)
// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"
// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"

def OpH : NS_Op<"op_for_arr_value_at_index", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemEq<0, 8>]>:$attr);
}

// CHECK-LABEL: OpHAdaptor::verify
// CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() == 8)))))
// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"
// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"

def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemMinValue<0, 8>]>:$attr);
}

// CHECK-LABEL: OpIAdaptor::verify
// CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() >= 8)))))
// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"
// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"

def OpJ: NS_Op<"op_for_arr_max_value_at_index", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemMaxValue<0, 8>]>:$attr);
}

// CHECK-LABEL: OpJAdaptor::verify
// CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() <= 8)))))
// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at most 8"
// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at most 8"

def OpK: NS_Op<"op_for_arr_in_range_at_index", []> {
let arguments = (ins ConfinedAttr<ArrayAttr, [IntArrayNthElemInRange<0, 4, 8>]>:$attr);
}

// CHECK-LABEL: OpKAdaptor::verify
// CHECK: (((::llvm::cast<::mlir::ArrayAttr>(tblgen_attr).size() > 0)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() >= 4)) && ((::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>(tblgen_attr)[0]).getInt() <= 8)))))
// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 4 and at most 8"
// CHECK-NEXT: attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 4 and at most 8"

def OpL: NS_Op<"op_for_TCopVTEtAreSameAt", [
PredOpTrait<"operands indexed at 0, 2, 3 should all have "
Expand All @@ -121,7 +121,7 @@ def OpL: NS_Op<"op_for_TCopVTEtAreSameAt", [
// CHECK: ::llvm::all_equal(::llvm::map_range(
// CHECK-SAME: ::mlir::ArrayRef<unsigned>({0, 2, 3}),
// CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); }))
// CHECK: "failed to verify that operands indexed at 0, 2, 3 should all have the same type"
// CHECK: failed to verify that operands indexed at 0, 2, 3 should all have the same type"

def OpM : NS_Op<"op_for_AnyTensorOf", []> {
let arguments = (ins TensorOf<[F32, I32]>:$x);
Expand Down
Loading