Skip to content

[flang][acc] Lower do and do concurrent loops specially in acc regions #149614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion flang/include/flang/Lower/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct ProcedureDesignator;

namespace parser {
struct AccClauseList;
struct DoConstruct;
struct OpenACCConstruct;
struct OpenACCDeclarativeConstruct;
struct OpenACCRoutineConstruct;
Expand All @@ -58,6 +59,7 @@ namespace lower {

class AbstractConverter;
class StatementContext;
class SymMap;

namespace pft {
struct Evaluation;
Expand Down Expand Up @@ -114,14 +116,32 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
mlir::Location);

int64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &);
/// Returns the number of nested loops to look for, which depends on the
/// number of tile operands and on the collapse clause.
uint64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &);

/// Checks whether the current insertion point is inside an OpenACC loop.
bool isInOpenACCLoop(fir::FirOpBuilder &);

/// Checks whether the current insertion point is inside an OpenACC compute
/// construct.
bool isInsideOpenACCComputeConstruct(fir::FirOpBuilder &);

void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);

void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location);

/// Generates an OpenACC loop from a do construct in order to properly
/// capture the loop bounds and the parallelism determination mode, and to
/// privatize the loop variables.
/// Returns nullptr when the conversion is rejected.
mlir::Operation *genOpenACCLoopFromDoConstruct(
AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::SymMap &localSymbols,
const Fortran::parser::DoConstruct &doConstruct,
pft::Evaluation &eval);

} // namespace lower
} // namespace Fortran

Expand Down
33 changes: 29 additions & 4 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2164,10 +2164,35 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// - structured and unstructured concurrent loops
void genFIR(const Fortran::parser::DoConstruct &doConstruct) {
setCurrentPositionAt(doConstruct);
// Collect loop nest information.
// Generate begin loop code directly for infinite and while loops.
Fortran::lower::pft::Evaluation &eval = getEval();
bool unstructuredContext = eval.lowerAsUnstructured();

// Loops with induction variables inside OpenACC compute constructs
// need special handling to ensure that the IVs are privatized.
if (Fortran::lower::isInsideOpenACCComputeConstruct(*builder)) {
mlir::Operation* loopOp = Fortran::lower::genOpenACCLoopFromDoConstruct(
*this, bridge.getSemanticsContext(), localSymbols,
doConstruct, eval);
bool success = loopOp != nullptr;
if (success) {
// Sanity check that the builder insertion point is inside the newly
// generated loop.
assert(
loopOp->getRegion(0).isAncestor(
builder->getInsertionPoint()->getBlock()->getParent()) &&
"builder insertion point is not inside the newly generated loop");

// Loop body code.
auto iter = eval.getNestedEvaluations().begin();
for (auto end = --eval.getNestedEvaluations().end(); iter != end; ++iter)
genFIR(*iter, unstructuredContext);
return;
}
// Fall back to normal loop handling.
}

// Collect loop nest information.
// Generate begin loop code directly for infinite and while loops.
Fortran::lower::pft::Evaluation &doStmtEval =
eval.getFirstNestedEvaluation();
auto *doStmt = doStmtEval.getIf<Fortran::parser::NonLabelDoStmt>();
Expand Down Expand Up @@ -3121,7 +3146,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::lower::pft::Evaluation *curEval = &getEval();

if (accLoop || accCombined) {
int64_t loopCount;
uint64_t loopCount;
if (accLoop) {
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
Expand All @@ -3139,7 +3164,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {

if (curEval->lowerAsStructured()) {
curEval = &curEval->getFirstNestedEvaluation();
for (int64_t i = 1; i < loopCount; i++)
for (uint64_t i = 1; i < loopCount; i++)
curEval = &*std::next(curEval->getNestedEvaluations().begin());
}
}
Expand Down
Loading
Loading