Skip to content

[MLIR][OpenMP] Add lowering support for AUTOMAP modifier #151513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/include/flang/Lower/OpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ struct Variable;
struct OMPDeferredDeclareTargetInfo {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point, this structure and the DeclareTargetCapturePair structure overlap quite a bit. Would it make sense to combine both into the same structure, or to make this one inherit from the other?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't inherit from DeclareTargetCaptureInfo because it's defined in flang/lib/Lower/OpenMP/Utils.h, which is a private header, so I left this as is for now.

mlir::omp::DeclareTargetCaptureClause declareTargetCaptureClause;
mlir::omp::DeclareTargetDeviceType declareTargetDeviceType;
bool automap = false;
const Fortran::semantics::Symbol &sym;
};

Expand Down
19 changes: 11 additions & 8 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1179,12 +1179,13 @@ bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
}

bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
return findRepeatableClause<omp::clause::Link>(
[&](const omp::clause::Link &clause, const parser::CharBlock &) {
// Case: declare target link(var1, var2)...
gatherFuncAndVarSyms(
clause.v, mlir::omp::DeclareTargetCaptureClause::link, result);
clause.v, mlir::omp::DeclareTargetCaptureClause::link, result,
/*automap=*/false);
});
}

Expand Down Expand Up @@ -1507,26 +1508,28 @@ bool ClauseProcessor::processTaskReduction(
}

bool ClauseProcessor::processTo(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
return findRepeatableClause<omp::clause::To>(
[&](const omp::clause::To &clause, const parser::CharBlock &) {
// Case: declare target to(func, var1, var2)...
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
mlir::omp::DeclareTargetCaptureClause::to, result);
mlir::omp::DeclareTargetCaptureClause::to, result,
/*automap=*/false);
});
}

bool ClauseProcessor::processEnter(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const {
return findRepeatableClause<omp::clause::Enter>(
[&](const omp::clause::Enter &clause, const parser::CharBlock &source) {
mlir::Location currentLocation = converter.genLocation(source);
if (std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t))
TODO(currentLocation, "Declare target enter AUTOMAP modifier");
bool automap =
std::get<std::optional<omp::clause::Enter::Modifier>>(clause.t)
.has_value();
// Case: declare target enter(func, var1, var2)...
gatherFuncAndVarSyms(std::get<ObjectList>(clause.t),
mlir::omp::DeclareTargetCaptureClause::enter,
result);
result, automap);
});
}

Expand Down
6 changes: 3 additions & 3 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ClauseProcessor {
bool processDepend(lower::SymMap &symMap, lower::StatementContext &stmtCtx,
mlir::omp::DependClauseOps &result) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
processEnter(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
mlir::omp::IfClauseOps &result) const;
bool processInReduction(
Expand All @@ -129,7 +129,7 @@ class ClauseProcessor {
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processLinear(mlir::omp::LinearClauseOps &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
processLink(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;

// This method is used to process a map clause.
// The optional parameter mapSyms is used to store the original Fortran symbol
Expand All @@ -150,7 +150,7 @@ class ClauseProcessor {
bool processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
mlir::omp::UseDeviceAddrClauseOps &result,
Expand Down
47 changes: 22 additions & 25 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,14 +764,14 @@ static void getDeclareTargetInfo(
lower::pft::Evaluation &eval,
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
mlir::omp::DeclareTargetOperands &clauseOps,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause) {
const auto &spec =
std::get<parser::OmpDeclareTargetSpecifier>(declareTargetConstruct.t);
if (const auto *objectList{parser::Unwrap<parser::OmpObjectList>(spec.u)}) {
ObjectList objects{makeObjects(*objectList, semaCtx)};
// Case: declare target(func, var1, var2)
gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
symbolAndClause);
symbolAndClause, /*automap=*/false);
} else if (const auto *clauseList{
parser::Unwrap<parser::OmpClauseList>(spec.u)}) {
List<Clause> clauses = makeClauses(*clauseList, semaCtx);
Expand Down Expand Up @@ -804,21 +804,20 @@ static void collectDeferredDeclareTargets(
llvm::SmallVectorImpl<lower::OMPDeferredDeclareTargetInfo>
&deferredDeclareTarget) {
mlir::omp::DeclareTargetOperands clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);
// Record every symbol whose operation cannot be found in the module yet, so
// that it can be marked as declare target later.
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();

for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(
converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
mlir::Operation *op =
mod.lookupSymbol(converter.mangleName(symClause.symbol));

if (!op) {
deferredDeclareTarget.push_back({std::get<0>(symClause),
clauseOps.deviceType,
std::get<1>(symClause)});
deferredDeclareTarget.push_back({symClause.clause, clauseOps.deviceType,
symClause.automap, symClause.symbol});
}
}
}
Expand All @@ -829,16 +828,16 @@ getDeclareTargetFunctionDevice(
lower::pft::Evaluation &eval,
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) {
mlir::omp::DeclareTargetOperands clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);

// Return the device type only if at least one of the targets for the
// directive is a function or subroutine
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(
converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
mlir::Operation *op =
mod.lookupSymbol(converter.mangleName(symClause.symbol));

if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
return clauseOps.deviceType;
Expand Down Expand Up @@ -1055,7 +1054,7 @@ getImplicitMapTypeAndKind(fir::FirOpBuilder &firOpBuilder,
static void
markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
mlir::omp::DeclareTargetCaptureClause captureClause,
mlir::omp::DeclareTargetDeviceType deviceType) {
mlir::omp::DeclareTargetDeviceType deviceType, bool automap) {
// TODO: Add support for program local variables with declare target applied
auto declareTargetOp = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(op);
if (!declareTargetOp)
Expand All @@ -1070,11 +1069,11 @@ markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
if (declareTargetOp.isDeclareTarget()) {
if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
declareTargetOp.setDeclareTarget(mlir::omp::DeclareTargetDeviceType::any,
captureClause);
captureClause, automap);
return;
}

declareTargetOp.setDeclareTarget(deviceType, captureClause);
declareTargetOp.setDeclareTarget(deviceType, captureClause, automap);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3540,25 +3539,23 @@ genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
const parser::OpenMPDeclareTargetConstruct &declareTargetConstruct) {
mlir::omp::DeclareTargetOperands clauseOps;
llvm::SmallVector<DeclareTargetCapturePair> symbolAndClause;
llvm::SmallVector<DeclareTargetCaptureInfo> symbolAndClause;
mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
getDeclareTargetInfo(converter, semaCtx, eval, declareTargetConstruct,
clauseOps, symbolAndClause);

for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
mlir::Operation *op = mod.lookupSymbol(
converter.mangleName(std::get<const semantics::Symbol &>(symClause)));
for (const DeclareTargetCaptureInfo &symClause : symbolAndClause) {
mlir::Operation *op =
mod.lookupSymbol(converter.mangleName(symClause.symbol));

// Some symbols are deferred until later in the module; these are handled
// upon finalization of the module for OpenMP inside of Bridge, so we simply
// skip them for now.
if (!op)
continue;

markDeclareTarget(
op, converter,
std::get<mlir::omp::DeclareTargetCaptureClause>(symClause),
clauseOps.deviceType);
markDeclareTarget(op, converter, symClause.clause, clauseOps.deviceType,
symClause.automap);
}
}

Expand Down Expand Up @@ -4141,7 +4138,7 @@ bool Fortran::lower::markOpenMPDeferredDeclareTargetFunctions(
deviceCodeFound = true;

markDeclareTarget(op, converter, declTar.declareTargetCaptureClause,
devType);
devType, declTar.automap);
}

return deviceCodeFound;
Expand Down
5 changes: 3 additions & 2 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ getIterationVariableSymbol(const lower::pft::Evaluation &eval) {

void gatherFuncAndVarSyms(
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause,
bool automap) {
for (const Object &object : objects)
symbolAndClause.emplace_back(clause, *object.sym());
symbolAndClause.emplace_back(clause, *object.sym(), automap);
}

mlir::omp::MapInfoOp
Expand Down
14 changes: 11 additions & 3 deletions flang/lib/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ class AbstractConverter;

namespace omp {

using DeclareTargetCapturePair =
std::pair<mlir::omp::DeclareTargetCaptureClause, const semantics::Symbol &>;
// Describes one symbol named by a declare target directive: the capture
// clause it was listed under, whether the AUTOMAP modifier applies to it, and
// the symbol itself. One entry is created per object by gatherFuncAndVarSyms.
struct DeclareTargetCaptureInfo {
// Capture clause (to / link / enter) the symbol was gathered from.
mlir::omp::DeclareTargetCaptureClause clause;
// True only when the symbol came from an ENTER clause that carried a
// modifier (AUTOMAP); every other gathering path passes false.
bool automap = false;
// Held by reference: this struct does not own the symbol.
const semantics::Symbol &symbol;

// Note the argument order is (clause, symbol, automap), which differs from
// the member declaration order; `a` defaults to false so callers that do not
// care about AUTOMAP can pass just the clause and the symbol.
DeclareTargetCaptureInfo(mlir::omp::DeclareTargetCaptureClause c,
const semantics::Symbol &s, bool a = false)
: clause(c), automap(a), symbol(s) {}
};

// A small helper structure for keeping track of a component member's MapInfoOp
// and index data when lowering OpenMP map clauses. Keeps track of the
Expand Down Expand Up @@ -150,7 +157,8 @@ getIterationVariableSymbol(const lower::pft::Evaluation &eval);

void gatherFuncAndVarSyms(
const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause,
llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause);
llvm::SmallVectorImpl<DeclareTargetCaptureInfo> &symbolAndClause,
bool automap = false);

int64_t getCollapseValue(const List<Clause> &clauses);

Expand Down
5 changes: 3 additions & 2 deletions flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ class FunctionFilteringPass
return WalkResult::skip();
}
if (declareTargetOp)
declareTargetOp.setDeclareTarget(declareType,
omp::DeclareTargetCaptureClause::to);
declareTargetOp.setDeclareTarget(
declareType, omp::DeclareTargetCaptureClause::to,
declareTargetOp.getDeclareTargetAutomap());
}
return WalkResult::advance();
});
Expand Down
21 changes: 13 additions & 8 deletions flang/lib/Optimizer/OpenMP/MarkDeclareTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class MarkDeclareTargetPass

void markNestedFuncs(mlir::omp::DeclareTargetDeviceType parentDevTy,
mlir::omp::DeclareTargetCaptureClause parentCapClause,
mlir::Operation *currOp,
bool parentAutomap, mlir::Operation *currOp,
llvm::SmallPtrSet<mlir::Operation *, 16> visited) {
if (visited.contains(currOp))
return;
Expand All @@ -57,13 +57,16 @@ class MarkDeclareTargetPass
currentDt != mlir::omp::DeclareTargetDeviceType::any) {
current.setDeclareTarget(
mlir::omp::DeclareTargetDeviceType::any,
current.getDeclareTargetCaptureClause());
current.getDeclareTargetCaptureClause(),
current.getDeclareTargetAutomap());
}
} else {
current.setDeclareTarget(parentDevTy, parentCapClause);
current.setDeclareTarget(parentDevTy, parentCapClause,
parentAutomap);
}

markNestedFuncs(parentDevTy, parentCapClause, currFOp, visited);
markNestedFuncs(parentDevTy, parentCapClause, parentAutomap,
currFOp, visited);
}
}
}
Expand All @@ -81,7 +84,8 @@ class MarkDeclareTargetPass
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
markNestedFuncs(declareTargetOp.getDeclareTargetDeviceType(),
declareTargetOp.getDeclareTargetCaptureClause(),
functionOp, visited);
declareTargetOp.getDeclareTargetAutomap(), functionOp,
visited);
}
}

Expand All @@ -92,9 +96,10 @@ class MarkDeclareTargetPass
// the contents of the device clause
getOperation()->walk([&](mlir::omp::TargetOp tarOp) {
llvm::SmallPtrSet<mlir::Operation *, 16> visited;
markNestedFuncs(mlir::omp::DeclareTargetDeviceType::nohost,
mlir::omp::DeclareTargetCaptureClause::to, tarOp,
visited);
markNestedFuncs(
/*parentDevTy=*/mlir::omp::DeclareTargetDeviceType::nohost,
/*parentCapClause=*/mlir::omp::DeclareTargetCaptureClause::to,
/*parentAutomap=*/false, tarOp, visited);
});
}
};
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenMP/common-block-map.f90
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

!CHECK: fir.global common @var_common_(dense<0> : vector<8xi8>) {{.*}} : !fir.array<8xi8>
!CHECK: fir.global common @var_common_link_(dense<0> : vector<8xi8>) {{{.*}} omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link)>} : !fir.array<8xi8>
!CHECK: fir.global common @var_common_link_(dense<0> : vector<8xi8>) {{{.*}} omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (link), automap = false>} : !fir.array<8xi8>

!CHECK-LABEL: func.func @_QPmap_full_block
!CHECK: %[[CB_ADDR:.*]] = fir.address_of(@var_common_) : !fir.ref<!fir.array<8xi8>>
Expand Down
Loading