[flang][OpenMP] Enable tiling #143715

Open
wants to merge 25 commits into
base: main
25 commits
ae81104
Initial implementation of tiling.
jsjodin Mar 14, 2025
dc9506e
Fix tests and limit the nesting of construct to only tiling.
jsjodin Jun 10, 2025
99f125e
Enable stand-alone tiling, but it gives a warning and converting to s…
jsjodin Jun 10, 2025
20b9903
Add minimal test, remove debug print.
jsjodin Jun 11, 2025
6d56b99
Fix formatting
jsjodin Jun 13, 2025
69f0d94
Fix formatting
jsjodin Jun 14, 2025
13ddc87
Fix test.
jsjodin Jun 19, 2025
6056202
Add more mlir tests. Set collapse value when lowering from SCF to Ope…
jsjodin Jun 20, 2025
6ddacf5
Use llvm::SmallVector instead of std::stack
jsjodin Jun 20, 2025
b3a1da2
Improve test a bit to make sure IVs are used as expected.
jsjodin Jun 21, 2025
cf6b1d6
Fix comments to clarify canonicalization.
jsjodin Jun 21, 2025
9d886dd
Special handling of tile directive when dealing with start end end lo…
jsjodin Jun 21, 2025
323527d
Inline functions.
jsjodin Jun 21, 2025
d87797c
Remove debug code.
jsjodin Jun 23, 2025
8c70c88
Reuse loop op lowering, add comment.
jsjodin Jun 23, 2025
531fe64
Fix formatting.
jsjodin Jun 23, 2025
2923141
Remove curly braces.
jsjodin Jun 23, 2025
5939e18
Avoid attaching the sizes clause to the parent construct, instead fin…
jsjodin Jun 25, 2025
583e042
Fix formatting
jsjodin Jun 25, 2025
219595f
Fix unparse and add a test for nested loop constructs.
jsjodin Jun 26, 2025
356c166
Use more convenient function to get OpenMPLoopConstruct. Fix comments.
jsjodin Jun 26, 2025
955ccf6
Fix formatting.
jsjodin Jun 26, 2025
a0501d0
Fix merge problems related to the different representations used for …
jsjodin Aug 9, 2025
aea0767
Fix bugs introduced when merging.
jsjodin Aug 9, 2025
5dca0a3
Move include
jsjodin Aug 11, 2025
1 change: 0 additions & 1 deletion flang/include/flang/Lower/OpenMP.h
@@ -80,7 +80,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
void genOpenMPSymbolProperties(AbstractConverter &converter,
const pft::Variable &var);

int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
79 changes: 58 additions & 21 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -47,6 +47,7 @@

using namespace Fortran::lower::omp;
using namespace Fortran::common::openmp;
using namespace Fortran::semantics;

//===----------------------------------------------------------------------===//
// Code generation helper functions
@@ -404,6 +405,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
return;

const parser::OmpClauseList *beginClauseList = nullptr;
const parser::OmpClauseList *middleClauseList = nullptr;
const parser::OmpClauseList *endClauseList = nullptr;
common::visit(
common::visitors{
@@ -418,6 +420,28 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
beginClauseList =
&std::get<parser::OmpClauseList>(beginDirective.t);

// For now we check if there is an inner OpenMPLoopConstruct, and
// extract the size clause from there
const auto &nestedOptional =
std::get<std::optional<parser::NestedConstruct>>(
ompConstruct.t);
assert(nestedOptional.has_value() &&
"Expected a DoConstruct or OpenMPLoopConstruct");
const auto *innerConstruct =
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&(nestedOptional.value()));
if (innerConstruct) {
const auto &innerLoopConstruct = innerConstruct->value();
const auto &innerBegin =
std::get<parser::OmpBeginLoopDirective>(
innerLoopConstruct.t);
const auto &innerDirective =
std::get<parser::OmpLoopDirective>(innerBegin.t);
if (innerDirective.v == llvm::omp::Directive::OMPD_tile) {
middleClauseList =
&std::get<parser::OmpClauseList>(innerBegin.t);
}
}
if (auto &endDirective =
std::get<std::optional<parser::OmpEndLoopDirective>>(
ompConstruct.t)) {
@@ -431,6 +455,9 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
assert(beginClauseList && "expected begin directive");
clauses.append(makeClauses(*beginClauseList, semaCtx));

if (middleClauseList)
clauses.append(makeClauses(*middleClauseList, semaCtx));

if (endClauseList)
clauses.append(makeClauses(*endClauseList, semaCtx));
};
@@ -910,6 +937,7 @@ static void genLoopVars(
storeOp =
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
}

Contributor

nit: Extra whitespace

firOpBuilder.setInsertionPointAfter(storeOp);
}

@@ -1660,6 +1688,30 @@ genLoopNestClauses(lower::AbstractConverter &converter,
cp.processCollapse(loc, eval, clauseOps, iv);

clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();

fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
for (auto &clause : clauses) {
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
const auto &collapse = std::get<clause::Collapse>(clause.u);
int64_t collapseValue = evaluate::ToInt64(collapse.v).value();
clauseOps.numCollapse = firOpBuilder.getI64IntegerAttr(collapseValue);
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
// This case handles the stand-alone tiling construct
const auto &sizes = std::get<clause::Sizes>(clause.u);
llvm::SmallVector<int64_t> sizeValues;
for (auto &size : sizes.v) {
int64_t sizeValue = evaluate::ToInt64(size).value();
sizeValues.push_back(sizeValue);
}
clauseOps.tileSizes = sizeValues;
}
}

llvm::SmallVector<int64_t> sizeValues;
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
collectTileSizesFromOpenMPConstruct(ompCons, sizeValues, semaCtx);
if (sizeValues.size() > 0)
clauseOps.tileSizes = sizeValues;
}

static void genLoopClauses(
@@ -2036,9 +2088,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
return llvm::SmallVector<const semantics::Symbol *>(iv);
};

auto *nestedEval =
getCollapsedLoopEval(eval, getCollapseValue(item->clauses));

uint64_t nestValue = getCollapseValue(item->clauses);
nestValue = nestValue < iv.size() ? iv.size() : nestValue;
auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
return genOpWithBody<mlir::omp::LoopNestOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
directive)
@@ -3449,13 +3501,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
item);
break;
case llvm::omp::Directive::OMPD_tile: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
if (!semaCtx.langOptions().OpenMPSimd)
TODO(loc, "Unhandled loop directive (" +
llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
case llvm::omp::Directive::OMPD_tile:
newOp = genLoopOp(converter, symTable, semaCtx, eval, loc, queue, item);
break;
}
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
@@ -3890,6 +3938,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
std::get<parser::OmpBeginLoopDirective>(loopConstruct.t);
List<Clause> clauses = makeClauses(
std::get<parser::OmpClauseList>(beginLoopDirective.t), semaCtx);

if (auto &endLoopDirective =
std::get<std::optional<parser::OmpEndLoopDirective>>(
loopConstruct.t)) {
@@ -4021,18 +4070,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
lower::genDeclareTargetIntGlobal(converter, var);
}

int64_t
Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
for (const parser::OmpClause &clause : clauseList.v) {
if (const auto &collapseClause =
std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
const auto *expr = semantics::GetExpr(collapseClause->v);
return evaluate::ToInt64(*expr).value();
}
}
return 1;
}

void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
const lower::pft::Variable &var) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
120 changes: 112 additions & 8 deletions flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,6 +13,7 @@
#include "Utils.h"

#include "ClauseFinder.h"
#include "flang/Evaluate/fold.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
@@ -24,10 +25,30 @@
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
#include <flang/Semantics/type.h>
#include <llvm/Support/CommandLine.h>

#include <iterator>

using namespace Fortran::semantics;

template <typename T>
MaybeIntExpr EvaluateIntExpr(SemanticsContext &context, const T &expr) {
if (MaybeExpr maybeExpr{
Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) {
if (auto *intExpr{Fortran::evaluate::UnwrapExpr<SomeIntExpr>(*maybeExpr)}) {
return std::move(*intExpr);
}
}
return std::nullopt;
}

template <typename T>
std::optional<std::int64_t> EvaluateInt64(SemanticsContext &context,
const T &expr) {
return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr));
}

llvm::cl::opt<bool> treatIndexAsSection(
"openmp-treat-index-as-section",
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
@@ -38,14 +59,21 @@ namespace lower {
namespace omp {

int64_t getCollapseValue(const List<Clause> &clauses) {
auto iter = llvm::find_if(clauses, [](const Clause &clause) {
return clause.id == llvm::omp::Clause::OMPC_collapse;
});
if (iter != clauses.end()) {
const auto &collapse = std::get<clause::Collapse>(iter->u);
return evaluate::ToInt64(collapse.v).value();
int64_t collapseValue = 1;
int64_t numTileSizes = 0;
for (auto &clause : clauses) {
if (clause.id == llvm::omp::Clause::OMPC_collapse) {
const auto &collapse = std::get<clause::Collapse>(clause.u);
collapseValue = evaluate::ToInt64(collapse.v).value();
} else if (clause.id == llvm::omp::Clause::OMPC_sizes) {
const auto &sizes = std::get<clause::Sizes>(clause.u);
numTileSizes = sizes.v.size();
}
Comment on lines +65 to +71
Member

This mixes the directive that collapse goes on (DO, DISTRIBUTE, ...) with the one that sizes goes on (TILE). While resolving merge conflicts between #145917 and #144785, the approach of handling tile like a compound construct already caused problems. For example, there could be two tile directives with different sizes, or unroll partial followed by tile versus the other way around, and there is no way to disambiguate them.

So it will need to be untangled eventually, and I would appreciate it if we could avoid the assumption that there is just one tile, which is always first.
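
To illustrate the ambiguity, here is a minimal sketch of what happens if the clauses of two tile directives end up in one flat list. The `Clause` struct and `countTileSizes` are hypothetical simplifications, not Flang types; only the scan shape (each sizes clause overwrites the count) follows the loop in `getCollapseValue` in this diff.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for one entry of a flattened clause list.
// It deliberately carries no record of which directive the clause came from.
enum class ClauseKind { Collapse, Sizes };
struct Clause {
  ClauseKind kind;
  std::vector<std::int64_t> values;
};

// Same scan shape as the loop in getCollapseValue: every sizes clause
// overwrites the count, so only the last one in the list is seen.
inline std::int64_t countTileSizes(const std::vector<Clause> &clauses) {
  std::int64_t numTileSizes = 0;
  for (const Clause &c : clauses) {
    if (c.kind == ClauseKind::Sizes) {
      assert(!c.values.empty() && "sizes clause needs at least one entry");
      numTileSizes = static_cast<std::int64_t>(c.values.size());
    }
  }
  return numTileSizes;
}
```

Under these assumptions, a list holding sizes(4,4) followed by sizes(2) yields 1, while the reverse order yields 2. Nothing in the flat list says which tile is the outer one.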

}
return 1;

collapseValue = collapseValue - numTileSizes;
int64_t result = collapseValue > numTileSizes ? collapseValue : numTileSizes;
return result;
Comment on lines +74 to +76
Member

Can you explain to me why we need this calculation? To see its effect, I tried replacing these few lines with simply "return collapseValue;" and ran all tests, but none failed. So it seems this part is not tested. A test would also help explain the purpose of the change.

Contributor Author

In general, this computes the number of loops that need to be considered in the source code. If you have collapse(4) on a loop nest with 2 loops, that would be incorrect, since we can collapse at most 2 loops. However, tiling creates new loops, so collapse(4) would theoretically be legal if tiling is done first, e.g. tile(5,10), since that results in 4 loops. This is not really testable, though, since collapse requires independent loops, which is only true for the 2 outer loops after tiling is done. There is a check for this, and an error message is given if the collapse value is larger than the number of loops that are tiled, to prevent incorrect code. We could just use numTileSizes if that is present, but if collapse could handle dependent loops in the future, the above calculation should be the correct one.
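
To make the arithmetic concrete, here is a minimal standalone sketch of the calculation as it appears in `getCollapseValue` in this diff. `effectiveLoopCount` is a hypothetical name, and plain integers stand in for the clause objects.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Hypothetical standalone mirror of the arithmetic in getCollapseValue.
// `collapse` is the collapse(n) argument (1 when the clause is absent);
// `numTileSizes` is the number of entries in sizes(...) (0 when absent).
inline std::int64_t effectiveLoopCount(std::int64_t collapse,
                                       std::int64_t numTileSizes) {
  assert(collapse >= 1 && numTileSizes >= 0);
  // Tiling k source loops yields 2k loops, so k of the loops named by
  // collapse are produced by tiling rather than written in the source.
  std::int64_t remaining = collapse - numTileSizes;
  // The tiled loops themselves must always be visited in the source.
  return std::max(remaining, numTileSizes);
}
```

With collapse(4) and two tile sizes this gives 2 source loops; with no sizes clause it returns the collapse value unchanged; with no collapse clause and two tile sizes it also gives 2.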

Member

Support for collapsing grid and tile loops has only been added in OpenMP 6.0. Before that, there was a restriction:
[image]

The value of the collapse clause should not really be modified implicitly. Is there a check in FortranSemantics that ensures collapse is not larger than the number of grid loops?
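
The kind of check being asked about could look roughly like the sketch below. This is a hypothetical helper, not the actual FortranSemantics code, and it assumes that a sizes clause with n entries produces n grid loops.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical check, not taken from Flang: returns true when collapse(n)
// only reaches the grid loops produced by tiling. Assumes one grid loop per
// entry in the sizes clause.
inline bool collapseFitsGridLoops(std::int64_t collapse,
                                  std::int64_t numTileSizes) {
  assert(collapse >= 1 && numTileSizes >= 1);
  return collapse <= numTileSizes;
}
```

Under that assumption, collapse(2) over a tile with two sizes passes, while collapse(4) over the same tile would be rejected instead of being adjusted implicitly.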

}

void genObjectList(const ObjectList &objects,
@@ -608,11 +636,52 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
}
}

// Populates the sizes vector with values if the given OpenMPConstruct
// Contains a loop construct with an inner tiling construct.
void collectTileSizesFromOpenMPConstruct(
const parser::OpenMPConstruct *ompCons,
llvm::SmallVectorImpl<int64_t> &tileSizes, SemanticsContext &semaCtx) {
if (!ompCons)
return;

if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const auto &nestedOptional =
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
assert(nestedOptional.has_value() &&
"Expected a DoConstruct or OpenMPLoopConstruct");
const auto *innerConstruct =
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&(nestedOptional.value()));
if (innerConstruct) {
const auto &innerLoopDirective = innerConstruct->value();
const auto &innerBegin =
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
const auto &innerDirective =
std::get<parser::OmpLoopDirective>(innerBegin.t).v;

if (innerDirective == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector
const auto &innerClauseList{
std::get<parser::OmpClauseList>(innerBegin.t)};
for (const auto &clause : innerClauseList.v)
if (const auto tclause{
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
for (auto &tval : tclause->v) {
if (const auto v{EvaluateInt64(semaCtx, tval)})
tileSizes.push_back(*v);
}
}
}
}
}
}

bool collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {

bool found = false;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

@@ -629,6 +698,42 @@ bool collectLoopRelatedInfo(
found = true;
}

// Collect sizes from tile directive if present
std::int64_t sizesLengthValue = 0l;
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const auto &nestedOptional =
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
assert(nestedOptional.has_value() &&
"Expected a DoConstruct or OpenMPLoopConstruct");
const auto *innerConstruct =
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&(nestedOptional.value()));
if (innerConstruct) {
const auto &innerLoopDirective = innerConstruct->value();
const auto &innerBegin =
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
const auto &innerDirective =
std::get<parser::OmpLoopDirective>(innerBegin.t).v;

if (innerDirective == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector
const auto &innerClauseList{
std::get<parser::OmpClauseList>(innerBegin.t)};
for (const auto &clause : innerClauseList.v)
if (const auto tclause{
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
sizesLengthValue = tclause->v.size();
found = true;
}
}
}
}
}

collapseValue = collapseValue - sizesLengthValue;
collapseValue =
collapseValue < sizesLengthValue ? sizesLengthValue : collapseValue;
std::size_t loopVarTypeSize = 0;
do {
lower::pft::Evaluation *doLoop =
@@ -661,7 +766,6 @@ bool collectLoopRelatedInfo(
} while (collapseValue > 0);

convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);

Member

[nit] unrelated change

return found;
}

5 changes: 5 additions & 0 deletions flang/lib/Lower/OpenMP/Utils.h
@@ -175,6 +175,11 @@ bool collectLoopRelatedInfo(
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);

void collectTileSizesFromOpenMPConstruct(
const parser::OpenMPConstruct *ompCons,
llvm::SmallVectorImpl<int64_t> &tileSizes,
Fortran::semantics::SemanticsContext &semaCtx);

} // namespace omp
} // namespace lower
} // namespace Fortran