Skip to content

Commit d452e67

Browse files
authored
[flang][OpenMP] Enable tiling (#143715)
This patch enables tiling in flang. In MLIR, tiling is handled by changing the omp.loop_nest op to be able to represent both collapse and tiling, so the flang front-end will combine the nested constructs into a single MLIR op. The MLIR->LLVM-IR lowering of the LoopNestOp is enhanced to first do the tiling if present, then collapse.
1 parent 106eb46 commit d452e67

File tree

26 files changed

+514
-117
lines changed

26 files changed

+514
-117
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
8080
void genOpenMPSymbolProperties(AbstractConverter &converter,
8181
const pft::Variable &var);
8282

83-
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
8483
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
8584
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
8685
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,10 +273,15 @@ bool ClauseProcessor::processCancelDirectiveName(
273273

274274
bool ClauseProcessor::processCollapse(
275275
mlir::Location currentLocation, lower::pft::Evaluation &eval,
276-
mlir::omp::LoopRelatedClauseOps &result,
276+
mlir::omp::LoopRelatedClauseOps &loopResult,
277+
mlir::omp::CollapseClauseOps &collapseResult,
277278
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
278-
return collectLoopRelatedInfo(converter, currentLocation, eval, clauses,
279-
result, iv);
279+
280+
int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
281+
clauses, loopResult, iv);
282+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
283+
collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
284+
return numCollapse > 1;
280285
}
281286

282287
bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
@@ -522,6 +527,13 @@ bool ClauseProcessor::processProcBind(
522527
return false;
523528
}
524529

530+
bool ClauseProcessor::processTileSizes(
531+
lower::pft::Evaluation &eval, mlir::omp::LoopNestOperands &result) const {
532+
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
533+
collectTileSizesFromOpenMPConstruct(ompCons, result.tileSizes, semaCtx);
534+
return !result.tileSizes.empty();
535+
}
536+
525537
bool ClauseProcessor::processSafelen(
526538
mlir::omp::SafelenClauseOps &result) const {
527539
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ class ClauseProcessor {
6363
mlir::omp::CancelDirectiveNameClauseOps &result) const;
6464
bool
6565
processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
66-
mlir::omp::LoopRelatedClauseOps &result,
66+
mlir::omp::LoopRelatedClauseOps &loopResult,
67+
mlir::omp::CollapseClauseOps &collapseResult,
6768
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
6869
bool processDevice(lower::StatementContext &stmtCtx,
6970
mlir::omp::DeviceClauseOps &result) const;
@@ -98,6 +99,8 @@ class ClauseProcessor {
9899
bool processPriority(lower::StatementContext &stmtCtx,
99100
mlir::omp::PriorityClauseOps &result) const;
100101
bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
102+
bool processTileSizes(lower::pft::Evaluation &eval,
103+
mlir::omp::LoopNestOperands &result) const;
101104
bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
102105
bool processSchedule(lower::StatementContext &stmtCtx,
103106
mlir::omp::ScheduleClauseOps &result) const;

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
503503
[[fallthrough]];
504504
case OMPD_distribute:
505505
case OMPD_distribute_simd:
506-
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
506+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
507507
break;
508508

509509
case OMPD_teams:
@@ -522,7 +522,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
522522
[[fallthrough]];
523523
case OMPD_target_teams_distribute:
524524
case OMPD_target_teams_distribute_simd:
525-
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
525+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
526526
cp.processNumTeams(stmtCtx, hostInfo->ops);
527527
break;
528528

@@ -533,7 +533,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
533533
cp.processNumTeams(stmtCtx, hostInfo->ops);
534534
[[fallthrough]];
535535
case OMPD_loop:
536-
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
536+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
537537
break;
538538

539539
case OMPD_teams_workdistribute:
@@ -1569,9 +1569,10 @@ genLoopNestClauses(lower::AbstractConverter &converter,
15691569

15701570
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
15711571
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
1572-
cp.processCollapse(loc, eval, clauseOps, iv);
1572+
cp.processCollapse(loc, eval, clauseOps, clauseOps, iv);
15731573

15741574
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
1575+
cp.processTileSizes(eval, clauseOps);
15751576
}
15761577

15771578
static void genLoopClauses(
@@ -1948,9 +1949,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
19481949
return llvm::SmallVector<const semantics::Symbol *>(iv);
19491950
};
19501951

1951-
auto *nestedEval =
1952-
getCollapsedLoopEval(eval, getCollapseValue(item->clauses));
1953-
1952+
uint64_t nestValue = getCollapseValue(item->clauses);
1953+
nestValue = nestValue < iv.size() ? iv.size() : nestValue;
1954+
auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
19541955
return genOpWithBody<mlir::omp::LoopNestOp>(
19551956
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
19561957
directive)
@@ -3843,8 +3844,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
38433844
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
38443845
switch (nestedDirective) {
38453846
case llvm::omp::Directive::OMPD_tile:
3846-
// Emit the omp.loop_nest with annotation for tiling
3847-
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value());
3847+
// Skip OMPD_tile since the tile sizes will be retrieved when
3848+
// generating the omp.loop_nest op.
38483849
break;
38493850
default: {
38503851
unsigned version = semaCtx.langOptions().OpenMPVersion;
@@ -3957,18 +3958,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
39573958
lower::genDeclareTargetIntGlobal(converter, var);
39583959
}
39593960

3960-
int64_t
3961-
Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
3962-
for (const parser::OmpClause &clause : clauseList.v) {
3963-
if (const auto &collapseClause =
3964-
std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
3965-
const auto *expr = semantics::GetExpr(collapseClause->v);
3966-
return evaluate::ToInt64(*expr).value();
3967-
}
3968-
}
3969-
return 1;
3970-
}
3971-
39723961
void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
39733962
const lower::pft::Variable &var) {
39743963
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

flang/lib/Lower/OpenMP/Utils.cpp

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "Utils.h"
1414

1515
#include "ClauseFinder.h"
16+
#include "flang/Evaluate/fold.h"
1617
#include "flang/Lower/OpenMP/Clauses.h"
1718
#include <flang/Lower/AbstractConverter.h>
1819
#include <flang/Lower/ConvertType.h>
@@ -24,11 +25,32 @@
2425
#include <flang/Parser/parse-tree.h>
2526
#include <flang/Parser/tools.h>
2627
#include <flang/Semantics/tools.h>
28+
#include <flang/Semantics/type.h>
2729
#include <flang/Utils/OpenMP.h>
2830
#include <llvm/Support/CommandLine.h>
2931

3032
#include <iterator>
3133

34+
template <typename T>
35+
Fortran::semantics::MaybeIntExpr
36+
EvaluateIntExpr(Fortran::semantics::SemanticsContext &context, const T &expr) {
37+
if (Fortran::semantics::MaybeExpr maybeExpr{
38+
Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) {
39+
if (auto *intExpr{
40+
Fortran::evaluate::UnwrapExpr<Fortran::semantics::SomeIntExpr>(
41+
*maybeExpr)}) {
42+
return std::move(*intExpr);
43+
}
44+
}
45+
return std::nullopt;
46+
}
47+
48+
template <typename T>
49+
std::optional<std::int64_t>
50+
EvaluateInt64(Fortran::semantics::SemanticsContext &context, const T &expr) {
51+
return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr));
52+
}
53+
3254
llvm::cl::opt<bool> treatIndexAsSection(
3355
"openmp-treat-index-as-section",
3456
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
@@ -577,12 +599,64 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
577599
}
578600
}
579601

580-
bool collectLoopRelatedInfo(
602+
// Helper function that finds the sizes clause in an inner OMPD_tile directive
603+
// and passes the sizes clause to the callback function if found.
604+
static void processTileSizesFromOpenMPConstruct(
605+
const parser::OpenMPConstruct *ompCons,
606+
std::function<void(const parser::OmpClause::Sizes *)> processFun) {
607+
if (!ompCons)
608+
return;
609+
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
610+
const auto &nestedOptional =
611+
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
612+
assert(nestedOptional.has_value() &&
613+
"Expected a DoConstruct or OpenMPLoopConstruct");
614+
const auto *innerConstruct =
615+
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
616+
&(nestedOptional.value()));
617+
if (innerConstruct) {
618+
const auto &innerLoopDirective = innerConstruct->value();
619+
const auto &innerBegin =
620+
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
621+
const auto &innerDirective =
622+
std::get<parser::OmpLoopDirective>(innerBegin.t).v;
623+
624+
if (innerDirective == llvm::omp::Directive::OMPD_tile) {
625+
// Get the size values from parse tree and convert to a vector.
626+
const auto &innerClauseList{
627+
std::get<parser::OmpClauseList>(innerBegin.t)};
628+
for (const auto &clause : innerClauseList.v) {
629+
if (const auto tclause{
630+
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
631+
processFun(tclause);
632+
break;
633+
}
634+
}
635+
}
636+
}
637+
}
638+
}
639+
640+
/// Populates the sizes vector with values if the given OpenMPConstruct
641+
/// contains a loop construct with an inner tiling construct.
642+
void collectTileSizesFromOpenMPConstruct(
643+
const parser::OpenMPConstruct *ompCons,
644+
llvm::SmallVectorImpl<int64_t> &tileSizes,
645+
Fortran::semantics::SemanticsContext &semaCtx) {
646+
processTileSizesFromOpenMPConstruct(
647+
ompCons, [&](const parser::OmpClause::Sizes *tclause) {
648+
for (auto &tval : tclause->v)
649+
if (const auto v{EvaluateInt64(semaCtx, tval)})
650+
tileSizes.push_back(*v);
651+
});
652+
}
653+
654+
int64_t collectLoopRelatedInfo(
581655
lower::AbstractConverter &converter, mlir::Location currentLocation,
582656
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
583657
mlir::omp::LoopRelatedClauseOps &result,
584658
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
585-
bool found = false;
659+
int64_t numCollapse = 1;
586660
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
587661

588662
// Collect the loops to collapse.
@@ -595,9 +669,19 @@ bool collectLoopRelatedInfo(
595669
if (auto *clause =
596670
ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) {
597671
collapseValue = evaluate::ToInt64(clause->v).value();
598-
found = true;
672+
numCollapse = collapseValue;
673+
}
674+
675+
// Collect sizes from tile directive if present.
676+
std::int64_t sizesLengthValue = 0l;
677+
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
678+
processTileSizesFromOpenMPConstruct(
679+
ompCons, [&](const parser::OmpClause::Sizes *tclause) {
680+
sizesLengthValue = tclause->v.size();
681+
});
599682
}
600683

684+
collapseValue = std::max(collapseValue, sizesLengthValue);
601685
std::size_t loopVarTypeSize = 0;
602686
do {
603687
lower::pft::Evaluation *doLoop =
@@ -631,7 +715,7 @@ bool collectLoopRelatedInfo(
631715

632716
convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
633717

634-
return found;
718+
return numCollapse;
635719
}
636720

637721
} // namespace omp

flang/lib/Lower/OpenMP/Utils.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,17 @@ void genObjectList(const ObjectList &objects,
159159
void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
160160
mlir::Location loc);
161161

162-
bool collectLoopRelatedInfo(
162+
int64_t collectLoopRelatedInfo(
163163
lower::AbstractConverter &converter, mlir::Location currentLocation,
164164
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
165165
mlir::omp::LoopRelatedClauseOps &result,
166166
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
167167

168+
void collectTileSizesFromOpenMPConstruct(
169+
const parser::OpenMPConstruct *ompCons,
170+
llvm::SmallVectorImpl<int64_t> &tileSizes,
171+
Fortran::semantics::SemanticsContext &semaCtx);
172+
168173
} // namespace omp
169174
} // namespace lower
170175
} // namespace Fortran

0 commit comments

Comments
 (0)