-
Notifications
You must be signed in to change notification settings - Fork 15k
[OpenMP][mlir] Add DynGroupPrivateClause in omp dialect #153562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this generally looks good to me. Please make sure that this protected under -fopenmp-version=61
I haven't seen any version check for any clause in omp dialect. Flang has checks for version when codegen to omp dialect. |
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-flang-openmp Author: Chaitanya (skc7) Changes
This PR enables dyn_groupprivate clause in openmp mlir dialect and adds it to Teams and Target ops. Full diff: https://github.com/llvm/llvm-project/pull/153562.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 311c57fb4446c..3c55b860d8a42 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1470,4 +1470,47 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V6.1 `dyn_groupprivate` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_DynGroupprivateClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+
+ let arguments = (ins
+ OptionalAttr<DynGroupprivateModifierAttr>:$modifier_first,
+ OptionalAttr<DynGroupprivateModifierAttr>:$modifier_second,
+ Optional<AnyInteger>:$dyn_groupprivate_size
+ );
+
+ let description = [{
+ The `dyn_groupprivate` clause allows you to dynamically allocate group-private
+ memory in OpenMP parallel regions, specifically for `target` and `teams` directives.
+ This clause enables runtime-sized private memory allocation and applicable to
+ target and teams ops.
+
+ Syntax:
+ ```
+ dyn_groupprivate(modifier_first ,modifier_second : dyn_groupprivate_size)
+ ```
+
+ Example:
+ ```
+ omp.target dyn_groupprivate(strict, cgroup : %dyn_groupprivate_size : i32)
+ ```
+ }];
+
+ let optAssemblyFormat = [{
+ `dyn_groupprivate` `(`
+ custom<DynGroupprivateClause>($modifier_first, $modifier_second,
+ $dyn_groupprivate_size, type($dyn_groupprivate_size))
+ `)`
+ }];
+}
+
+def OpenMP_DynGroupprivateClause : OpenMP_DynGroupprivateClauseSkip<>;
+
#endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index c080c3fac87d4..9869615f5e885 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -293,4 +293,26 @@ def AllocatorHandle : OpenMP_I32EnumAttr<
]>;
def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
+
+//===----------------------------------------------------------------------===//
+// dyn_groupprivate enum.
+//===----------------------------------------------------------------------===//
+
+def DynGroupprivateCGroup : I32EnumAttrCase<"cgroup", 0>;
+def DynGroupprivateStrict : I32EnumAttrCase<"strict", 1>;
+def DynGroupprivateFallback : I32EnumAttrCase<"fallback", 2>;
+
+def DynGroupprivateModifier : OpenMP_I32EnumAttr<
+ "DynGroupprivateModifier",
+ "dyn_groupprivate modifier", [
+ DynGroupprivateCGroup,
+ DynGroupprivateStrict,
+ DynGroupprivateFallback
+ ]>;
+
+def DynGroupprivateModifierAttr : OpenMP_EnumAttr<DynGroupprivateModifier,
+ "dyn_groupprivate_modifier"> {
+ let assemblyFormat = "`(` $value `)`";
+}
+
#endif // OPENMP_ENUMS
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index be114ea4fb631..315233102be47 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+ OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause,
+ OpenMP_DynGroupprivateClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
@@ -1464,7 +1465,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
- OpenMP_PrivateClause, OpenMP_ThreadLimitClause
+ OpenMP_PrivateClause, OpenMP_ThreadLimitClause, OpenMP_DynGroupprivateClause
], singleRegion = true> {
let summary = "target construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767ef90b0..b70dbb7987208 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -566,6 +566,167 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}
+//===----------------------------------------------------------------------===//
+// Parser, printer and verify for dyn_groupprivate Clause
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyDynGroupprivateClause(
+ Operation *op, DynGroupprivateModifierAttr modifierFirst,
+ DynGroupprivateModifierAttr modifierSecond, Value dynGroupprivateSize) {
+
+ // Helper to get modifier name as string
+ auto getModifierName = [](DynGroupprivateModifier mod) -> StringRef {
+ switch (mod) {
+ case DynGroupprivateModifier::strict:
+ return "strict";
+ case DynGroupprivateModifier::cgroup:
+ return "cgroup";
+ case DynGroupprivateModifier::fallback:
+ return "fallback";
+ }
+ return "unknown";
+ };
+
+ // Check for duplicate modifiers
+ if (modifierFirst && modifierSecond &&
+ modifierFirst.getValue() == modifierSecond.getValue()) {
+ return op->emitOpError("duplicate dyn_groupprivate modifier '")
+ << getModifierName(modifierFirst.getValue()) << "'";
+ }
+
+ // Check for incompatible modifier combinations
+ if (modifierFirst && modifierSecond) {
+ auto m1 = modifierFirst.getValue();
+ auto m2 = modifierSecond.getValue();
+
+ // strict and fallback are incompatible
+ if ((m1 == DynGroupprivateModifier::strict &&
+ m2 == DynGroupprivateModifier::fallback) ||
+ (m1 == DynGroupprivateModifier::fallback &&
+ m2 == DynGroupprivateModifier::strict)) {
+ return op->emitOpError("incompatible dyn_groupprivate modifiers: '")
+ << getModifierName(m1) << "' and '" << getModifierName(m2)
+ << "' cannot be used together";
+ }
+ }
+
+ // Verify the size
+ if (dynGroupprivateSize) {
+ Type size_type = dynGroupprivateSize.getType();
+ // Check if the size type is an integer type
+ if (!size_type.isIntOrIndex()) {
+ return op->emitOpError(
+ "dyn_groupprivate size must be an integer type, got ")
+ << size_type;
+ }
+ }
+
+ return success();
+}
+
+static ParseResult parseDynGroupprivateClause(
+ OpAsmParser &parser, DynGroupprivateModifierAttr &modifierFirst,
+ DynGroupprivateModifierAttr &modifierSecond,
+ std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
+ Type &size_type) {
+
+ bool hasModifiers = false;
+
+ // Parse first modifier if present
+ if (succeeded(parser.parseOptionalKeyword("strict"))) {
+ modifierFirst = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::strict);
+ hasModifiers = true;
+ } else if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
+ modifierFirst = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::cgroup);
+ hasModifiers = true;
+ } else if (succeeded(parser.parseOptionalKeyword("fallback"))) {
+ modifierFirst = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::fallback);
+ hasModifiers = true;
+ }
+
+ // If first modifier found, check for comma and second modifier
+ if (hasModifiers && succeeded(parser.parseOptionalComma())) {
+ if (succeeded(parser.parseOptionalKeyword("strict"))) {
+ modifierSecond = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::strict);
+ } else if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
+ modifierSecond = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::cgroup);
+ } else if (succeeded(parser.parseOptionalKeyword("fallback"))) {
+ modifierSecond = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::fallback);
+ } else {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected modifier after comma");
+ }
+ }
+
+ // Parse colon and size if modifiers were present, or just try to parse
+ // operand
+ if (hasModifiers) {
+ // Modifiers present, expect colon
+ if (failed(parser.parseColon())) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ':' after modifiers");
+ }
+
+ // Parse operand and type
+ OpAsmParser::UnresolvedOperand operand;
+ if (succeeded(parser.parseOperand(operand))) {
+ dynGroupprivateSize = operand;
+ if (failed(parser.parseColon()) || failed(parser.parseType(size_type))) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ':' and type after size operand");
+ }
+ }
+ } else {
+ // No modifiers, try to parse operand directly
+ OpAsmParser::UnresolvedOperand operand;
+ if (succeeded(parser.parseOperand(operand))) {
+ dynGroupprivateSize = operand;
+ if (failed(parser.parseColon()) || failed(parser.parseType(size_type))) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ':' and type after size operand");
+ }
+ } else {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected dyn_groupprivate_size operand");
+ }
+ }
+
+ return success();
+}
+
+static void
+printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op,
+ DynGroupprivateModifierAttr modifierFirst,
+ DynGroupprivateModifierAttr modifierSecond,
+ Value dynGroupprivateSize, Type size_type) {
+
+ bool needsComma = false;
+
+ if (modifierFirst) {
+ printer << modifierFirst.getValue();
+ needsComma = true;
+ }
+
+ if (modifierSecond) {
+ if (needsComma)
+ printer << ", ";
+ printer << modifierSecond.getValue();
+ needsComma = true;
+ }
+
+ if (dynGroupprivateSize) {
+ if (needsComma)
+ printer << " : ";
+ printer << dynGroupprivateSize << " : " << size_type;
+ }
+}
+
//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -1951,6 +2112,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.threadLimit,
+ clauses.modifierFirst, clauses.modifierSecond,
+ clauses.dynGroupprivateSize,
/*private_maps=*/nullptr);
}
@@ -1965,6 +2128,12 @@ LogicalResult TargetOp::verify() {
if (failed(verifyMapClause(*this, getMapVars())))
return failure();
+ // check dyn_groupprivate clause restrictions
+ if (failed(verifyDynGroupprivateClause(*this, getModifierFirstAttr(),
+ getModifierSecondAttr(),
+ getDynGroupprivateSize())))
+ return failure();
+
return verifyPrivateVarsMapping(*this);
}
@@ -2339,8 +2508,9 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms),
- clauses.threadLimit);
+ makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit,
+ clauses.modifierFirst, clauses.modifierSecond,
+ clauses.dynGroupprivateSize);
}
LogicalResult TeamsOp::verify() {
@@ -2371,6 +2541,12 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // check dyn_groupprivate clause restrictions
+ if (failed(verifyDynGroupprivateClause(op, getModifierFirstAttr(),
+ getModifierSecondAttr(),
+ getDynGroupprivateSize())))
+ return failure();
+
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index eb96cb211fdd5..a25098d22a8f3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -383,6 +383,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getUntied())
result = todo("untied");
};
+ auto checkDynGroupprivate = [&todo](auto op, LogicalResult &result) {
+ if (op.getDynGroupprivateSize())
+ result = todo("dyn_groupprivate");
+ };
LogicalResult result = success();
llvm::TypeSwitch<Operation &>(op)
@@ -408,6 +412,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
.Case([&](omp::TeamsOp op) {
checkAllocate(op, result);
checkPrivate(op, result);
+ checkDynGroupprivate(op, result);
})
.Case([&](omp::TaskOp op) {
checkAllocate(op, result);
@@ -451,6 +456,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkInReduction(op, result);
checkIsDevicePtr(op, result);
checkPrivate(op, result);
+ checkDynGroupprivate(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 5088f2dfa7d7a..0d77ddc530cbb 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1415,7 +1415,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
@@ -1428,7 +1428,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
@@ -1449,6 +1449,26 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
// -----
+func.func @test_teams_dyn_groupprivate_errors_1(%dyn_size: i32) {
+ // expected-error @below {{duplicate dyn_groupprivate modifier 'strict'}}
+ omp.teams dyn_groupprivate(strict, strict : %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @test_teams_dyn_groupprivate_errors_2(%dyn_size: i32) {
+ // expected-error @below {{incompatible dyn_groupprivate modifiers: 'strict' and 'fallback' cannot be used together}}
+ omp.teams dyn_groupprivate(strict, fallback : %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
func.func @omp_sections(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.sections" (%data_var) ({
@@ -2435,12 +2455,26 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
- }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+ }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}
// -----
+func.func @test_target_dyn_groupprivate_errors(%dyn_size: i32) {
+ // expected-error @below {{duplicate dyn_groupprivate modifier 'strict'}}
+ omp.target dyn_groupprivate(strict, strict : %dyn_size : i32) {
+ omp.terminator
+ }
+ // expected-error @below {{incompatible dyn_groupprivate modifiers: 'strict' and 'fallback' cannot be used together}}
+ omp.target dyn_groupprivate(strict, fallback : %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
func.func @omp_distribute_schedule(%chunk_size : i32, %lb : i32, %ub : i32, %step : i32) -> () {
// expected-error @below {{op chunk size set without dist_schedule_static being present}}
"omp.distribute"(%chunk_size) <{operandSegmentSizes = array<i32: 0, 0, 1, 0>}> ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 8c846cde1a3ca..64be91937ebad 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -769,7 +769,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%device, %if_cond, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1,0>} : ( si32, i1, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1022,7 +1022,7 @@ func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
// CHECK-LABEL: omp_teams
func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
- %data_var : memref<i32>) -> () {
+ %dyn_size : i32, %data_var : memref<i32>) -> () {
// Test nesting inside of omp.target
omp.target {
// CHECK: omp.teams
@@ -1092,6 +1092,13 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
// CHECK: omp.terminator
omp.terminator
}
+
+ // Test dyn_groupprivate
+ // CHECK: omp.teams dyn_groupprivate(cgroup, strict : %{{.+}} : i32)
+ omp.teams dyn_groupprivate(cgroup, strict : %dyn_size : i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
return
}
@@ -2153,6 +2160,28 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
return
}
+// CHECK-LABEL: @omp_target_dyn_groupprivate
+func.func @omp_target_dyn_groupprivate(%dyn_size: i32, %large_size: i64) {
+ // CHECK: omp.target dyn_groupprivate(strict, cgroup : %{{.*}} : i32)
+ omp.target dyn_groupprivate(strict, cgroup : %dyn_size : i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(cgroup : %{{.*}} : i64)
+ omp.target dyn_groupprivate(cgroup : %large_size : i64) {
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(fallback, cgroup : %{{.*}} : i32)
+ omp.target dyn_groupprivate(fallback, cgroup : %dyn_size : i32) {
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(%{{.*}} : i64)
+ omp.target dyn_groupprivate(%large_size : i64) {
+ omp.terminator
+ }
+ return
+}
+
func.func @omp_threadprivate() {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
|
@llvm/pr-subscribers-mlir-openmp Author: Chaitanya (skc7) Changes
This PR enables dyn_groupprivate clause in openmp mlir dialect and adds it to Teams and Target ops. Full diff: https://github.com/llvm/llvm-project/pull/153562.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 311c57fb4446c..3c55b860d8a42 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1470,4 +1470,47 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V6.1 `dyn_groupprivate` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_DynGroupprivateClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+
+ let arguments = (ins
+ OptionalAttr<DynGroupprivateModifierAttr>:$modifier_first,
+ OptionalAttr<DynGroupprivateModifierAttr>:$modifier_second,
+ Optional<AnyInteger>:$dyn_groupprivate_size
+ );
+
+ let description = [{
+ The `dyn_groupprivate` clause allows you to dynamically allocate group-private
+ memory in OpenMP parallel regions, specifically for `target` and `teams` directives.
+ This clause enables runtime-sized private memory allocation and applicable to
+ target and teams ops.
+
+ Syntax:
+ ```
+ dyn_groupprivate(modifier_first ,modifier_second : dyn_groupprivate_size)
+ ```
+
+ Example:
+ ```
+ omp.target dyn_groupprivate(strict, cgroup : %dyn_groupprivate_size : i32)
+ ```
+ }];
+
+ let optAssemblyFormat = [{
+ `dyn_groupprivate` `(`
+ custom<DynGroupprivateClause>($modifier_first, $modifier_second,
+ $dyn_groupprivate_size, type($dyn_groupprivate_size))
+ `)`
+ }];
+}
+
+def OpenMP_DynGroupprivateClause : OpenMP_DynGroupprivateClauseSkip<>;
+
#endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index c080c3fac87d4..9869615f5e885 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -293,4 +293,26 @@ def AllocatorHandle : OpenMP_I32EnumAttr<
]>;
def AllocatorHandleAttr : OpenMP_EnumAttr<AllocatorHandle, "allocator_handle">;
+
+//===----------------------------------------------------------------------===//
+// dyn_groupprivate enum.
+//===----------------------------------------------------------------------===//
+
+def DynGroupprivateCGroup : I32EnumAttrCase<"cgroup", 0>;
+def DynGroupprivateStrict : I32EnumAttrCase<"strict", 1>;
+def DynGroupprivateFallback : I32EnumAttrCase<"fallback", 2>;
+
+def DynGroupprivateModifier : OpenMP_I32EnumAttr<
+ "DynGroupprivateModifier",
+ "dyn_groupprivate modifier", [
+ DynGroupprivateCGroup,
+ DynGroupprivateStrict,
+ DynGroupprivateFallback
+ ]>;
+
+def DynGroupprivateModifierAttr : OpenMP_EnumAttr<DynGroupprivateModifier,
+ "dyn_groupprivate_modifier"> {
+ let assemblyFormat = "`(` $value `)`";
+}
+
#endif // OPENMP_ENUMS
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index be114ea4fb631..315233102be47 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+ OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause,
+ OpenMP_DynGroupprivateClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
@@ -1464,7 +1465,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
- OpenMP_PrivateClause, OpenMP_ThreadLimitClause
+ OpenMP_PrivateClause, OpenMP_ThreadLimitClause, OpenMP_DynGroupprivateClause
], singleRegion = true> {
let summary = "target construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767ef90b0..b70dbb7987208 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -566,6 +566,167 @@ static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
}
+//===----------------------------------------------------------------------===//
+// Parser, printer and verify for dyn_groupprivate Clause
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verifyDynGroupprivateClause(
+ Operation *op, DynGroupprivateModifierAttr modifierFirst,
+ DynGroupprivateModifierAttr modifierSecond, Value dynGroupprivateSize) {
+
+ // Helper to get modifier name as string
+ auto getModifierName = [](DynGroupprivateModifier mod) -> StringRef {
+ switch (mod) {
+ case DynGroupprivateModifier::strict:
+ return "strict";
+ case DynGroupprivateModifier::cgroup:
+ return "cgroup";
+ case DynGroupprivateModifier::fallback:
+ return "fallback";
+ }
+ return "unknown";
+ };
+
+ // Check for duplicate modifiers
+ if (modifierFirst && modifierSecond &&
+ modifierFirst.getValue() == modifierSecond.getValue()) {
+ return op->emitOpError("duplicate dyn_groupprivate modifier '")
+ << getModifierName(modifierFirst.getValue()) << "'";
+ }
+
+ // Check for incompatible modifier combinations
+ if (modifierFirst && modifierSecond) {
+ auto m1 = modifierFirst.getValue();
+ auto m2 = modifierSecond.getValue();
+
+ // strict and fallback are incompatible
+ if ((m1 == DynGroupprivateModifier::strict &&
+ m2 == DynGroupprivateModifier::fallback) ||
+ (m1 == DynGroupprivateModifier::fallback &&
+ m2 == DynGroupprivateModifier::strict)) {
+ return op->emitOpError("incompatible dyn_groupprivate modifiers: '")
+ << getModifierName(m1) << "' and '" << getModifierName(m2)
+ << "' cannot be used together";
+ }
+ }
+
+ // Verify the size
+ if (dynGroupprivateSize) {
+ Type size_type = dynGroupprivateSize.getType();
+ // Check if the size type is an integer type
+ if (!size_type.isIntOrIndex()) {
+ return op->emitOpError(
+ "dyn_groupprivate size must be an integer type, got ")
+ << size_type;
+ }
+ }
+
+ return success();
+}
+
+static ParseResult parseDynGroupprivateClause(
+ OpAsmParser &parser, DynGroupprivateModifierAttr &modifierFirst,
+ DynGroupprivateModifierAttr &modifierSecond,
+ std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
+ Type &size_type) {
+
+ bool hasModifiers = false;
+
+ // Parse first modifier if present
+ if (succeeded(parser.parseOptionalKeyword("strict"))) {
+ modifierFirst = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::strict);
+ hasModifiers = true;
+ } else if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
+ modifierFirst = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::cgroup);
+ hasModifiers = true;
+ } else if (succeeded(parser.parseOptionalKeyword("fallback"))) {
+ modifierFirst = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::fallback);
+ hasModifiers = true;
+ }
+
+ // If first modifier found, check for comma and second modifier
+ if (hasModifiers && succeeded(parser.parseOptionalComma())) {
+ if (succeeded(parser.parseOptionalKeyword("strict"))) {
+ modifierSecond = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::strict);
+ } else if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
+ modifierSecond = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::cgroup);
+ } else if (succeeded(parser.parseOptionalKeyword("fallback"))) {
+ modifierSecond = DynGroupprivateModifierAttr::get(
+ parser.getContext(), DynGroupprivateModifier::fallback);
+ } else {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected modifier after comma");
+ }
+ }
+
+ // Parse colon and size if modifiers were present, or just try to parse
+ // operand
+ if (hasModifiers) {
+ // Modifiers present, expect colon
+ if (failed(parser.parseColon())) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ':' after modifiers");
+ }
+
+ // Parse operand and type
+ OpAsmParser::UnresolvedOperand operand;
+ if (succeeded(parser.parseOperand(operand))) {
+ dynGroupprivateSize = operand;
+ if (failed(parser.parseColon()) || failed(parser.parseType(size_type))) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ':' and type after size operand");
+ }
+ }
+ } else {
+ // No modifiers, try to parse operand directly
+ OpAsmParser::UnresolvedOperand operand;
+ if (succeeded(parser.parseOperand(operand))) {
+ dynGroupprivateSize = operand;
+ if (failed(parser.parseColon()) || failed(parser.parseType(size_type))) {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected ':' and type after size operand");
+ }
+ } else {
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected dyn_groupprivate_size operand");
+ }
+ }
+
+ return success();
+}
+
+static void
+printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op,
+ DynGroupprivateModifierAttr modifierFirst,
+ DynGroupprivateModifierAttr modifierSecond,
+ Value dynGroupprivateSize, Type size_type) {
+
+ bool needsComma = false;
+
+ if (modifierFirst) {
+ printer << modifierFirst.getValue();
+ needsComma = true;
+ }
+
+ if (modifierSecond) {
+ if (needsComma)
+ printer << ", ";
+ printer << modifierSecond.getValue();
+ needsComma = true;
+ }
+
+ if (dynGroupprivateSize) {
+ if (needsComma)
+ printer << " : ";
+ printer << dynGroupprivateSize << " : " << size_type;
+ }
+}
+
//===----------------------------------------------------------------------===//
// Parsers for operations including clauses that define entry block arguments.
//===----------------------------------------------------------------------===//
@@ -1951,6 +2112,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.threadLimit,
+ clauses.modifierFirst, clauses.modifierSecond,
+ clauses.dynGroupprivateSize,
/*private_maps=*/nullptr);
}
@@ -1965,6 +2128,12 @@ LogicalResult TargetOp::verify() {
if (failed(verifyMapClause(*this, getMapVars())))
return failure();
+ // check dyn_groupprivate clause restrictions
+ if (failed(verifyDynGroupprivateClause(*this, getModifierFirstAttr(),
+ getModifierSecondAttr(),
+ getDynGroupprivateSize())))
+ return failure();
+
return verifyPrivateVarsMapping(*this);
}
@@ -2339,8 +2508,9 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms),
- clauses.threadLimit);
+ makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit,
+ clauses.modifierFirst, clauses.modifierSecond,
+ clauses.dynGroupprivateSize);
}
LogicalResult TeamsOp::verify() {
@@ -2371,6 +2541,12 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // check dyn_groupprivate clause restrictions
+ if (failed(verifyDynGroupprivateClause(op, getModifierFirstAttr(),
+ getModifierSecondAttr(),
+ getDynGroupprivateSize())))
+ return failure();
+
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index eb96cb211fdd5..a25098d22a8f3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -383,6 +383,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getUntied())
result = todo("untied");
};
+ auto checkDynGroupprivate = [&todo](auto op, LogicalResult &result) {
+ if (op.getDynGroupprivateSize())
+ result = todo("dyn_groupprivate");
+ };
LogicalResult result = success();
llvm::TypeSwitch<Operation &>(op)
@@ -408,6 +412,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
.Case([&](omp::TeamsOp op) {
checkAllocate(op, result);
checkPrivate(op, result);
+ checkDynGroupprivate(op, result);
})
.Case([&](omp::TaskOp op) {
checkAllocate(op, result);
@@ -451,6 +456,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkInReduction(op, result);
checkIsDevicePtr(op, result);
checkPrivate(op, result);
+ checkDynGroupprivate(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 5088f2dfa7d7a..0d77ddc530cbb 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1415,7 +1415,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
@@ -1428,7 +1428,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
@@ -1449,6 +1449,26 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
// -----
+func.func @test_teams_dyn_groupprivate_errors_1(%dyn_size: i32) {
+ // expected-error @below {{duplicate dyn_groupprivate modifier 'strict'}}
+ omp.teams dyn_groupprivate(strict, strict : %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @test_teams_dyn_groupprivate_errors_2(%dyn_size: i32) {
+ // expected-error @below {{incompatible dyn_groupprivate modifiers: 'strict' and 'fallback' cannot be used together}}
+ omp.teams dyn_groupprivate(strict, fallback : %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
func.func @omp_sections(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.sections" (%data_var) ({
@@ -2435,12 +2455,26 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
- }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+ }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}
// -----
+func.func @test_target_dyn_groupprivate_errors(%dyn_size: i32) {
+ // expected-error @below {{duplicate dyn_groupprivate modifier 'strict'}}
+ omp.target dyn_groupprivate(strict, strict : %dyn_size : i32) {
+ omp.terminator
+ }
+ // expected-error @below {{incompatible dyn_groupprivate modifiers: 'strict' and 'fallback' cannot be used together}}
+ omp.target dyn_groupprivate(strict, fallback : %dyn_size : i32) {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
func.func @omp_distribute_schedule(%chunk_size : i32, %lb : i32, %ub : i32, %step : i32) -> () {
// expected-error @below {{op chunk size set without dist_schedule_static being present}}
"omp.distribute"(%chunk_size) <{operandSegmentSizes = array<i32: 0, 0, 1, 0>}> ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 8c846cde1a3ca..64be91937ebad 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -769,7 +769,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%device, %if_cond, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1,0>} : ( si32, i1, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1022,7 +1022,7 @@ func.func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) {
// CHECK-LABEL: omp_teams
func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
- %data_var : memref<i32>) -> () {
+ %dyn_size : i32, %data_var : memref<i32>) -> () {
// Test nesting inside of omp.target
omp.target {
// CHECK: omp.teams
@@ -1092,6 +1092,13 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
// CHECK: omp.terminator
omp.terminator
}
+
+ // Test dyn_groupprivate
+ // CHECK: omp.teams dyn_groupprivate(cgroup, strict : %{{.+}} : i32)
+ omp.teams dyn_groupprivate(cgroup, strict : %dyn_size : i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
return
}
@@ -2153,6 +2160,28 @@ func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
return
}
+// CHECK-LABEL: @omp_target_dyn_groupprivate
+func.func @omp_target_dyn_groupprivate(%dyn_size: i32, %large_size: i64) {
+ // CHECK: omp.target dyn_groupprivate(strict, cgroup : %{{.*}} : i32)
+ omp.target dyn_groupprivate(strict, cgroup : %dyn_size : i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(cgroup : %{{.*}} : i64)
+ omp.target dyn_groupprivate(cgroup : %large_size : i64) {
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(fallback, cgroup : %{{.*}} : i32)
+ omp.target dyn_groupprivate(fallback, cgroup : %dyn_size : i32) {
+ omp.terminator
+ }
+ // CHECK: omp.target dyn_groupprivate(%{{.*}} : i64)
+ omp.target dyn_groupprivate(%large_size : i64) {
+ omp.terminator
+ }
+ return
+}
+
func.func @omp_threadprivate() {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
|
dyn_groupprivate
clause allows to dynamically allocate group-private memory in OpenMP parallel regions, specifically fortarget
andteams
directives.This PR enables dyn_groupprivate clause in openmp mlir dialect and adds it to Teams and Target ops. Also includes parser, printer and verification for clause.