Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
548684f
Initial implementation of tiling.
jsjodin Mar 14, 2025
02fb537
Fix tests and limit the nesting of construct to only tiling.
jsjodin Jun 10, 2025
3efaa77
Enable stand-alone tiling, but it gives a warning and converting to s…
jsjodin Jun 10, 2025
b7f09a0
Add minimal test, remove debug print.
jsjodin Jun 11, 2025
75b0ab5
Fix formatting
jsjodin Jun 13, 2025
dccb7f4
Fix formatting
jsjodin Jun 14, 2025
1e63230
Fix test.
jsjodin Jun 19, 2025
4b5412a
Add more mlir tests. Set collapse value when lowering from SCF to Ope…
jsjodin Jun 20, 2025
411882b
Use llvm::SmallVector instead of std::stack
jsjodin Jun 20, 2025
b39287e
Improve test a bit to make sure IVs are used as expected.
jsjodin Jun 21, 2025
5f51565
Fix comments to clarify canonicalization.
jsjodin Jun 21, 2025
e1eaf9a
Special handling of tile directive when dealing with start end end lo…
jsjodin Jun 21, 2025
ac41499
Inline functions.
jsjodin Jun 21, 2025
9f74cf1
Remove debug code.
jsjodin Jun 23, 2025
bb9132c
Reuse loop op lowering, add comment.
jsjodin Jun 23, 2025
6681b4e
Fix formatting.
jsjodin Jun 23, 2025
54593af
Remove curly braces.
jsjodin Jun 23, 2025
93d9952
Avoid attaching the sizes clause to the parent construct, instead fin…
jsjodin Jun 25, 2025
a738a56
Fix formatting
jsjodin Jun 25, 2025
9757360
Fix unparse and add a test for nested loop constructs.
jsjodin Jun 26, 2025
9e35a6f
Use more convenient function to get OpenMPLoopConstruct. Fix comments.
jsjodin Jun 26, 2025
8dfaf57
Fix formatting.
jsjodin Jun 26, 2025
84c5cc1
Fix merge problems related to the different representations used for …
jsjodin Aug 9, 2025
53ab195
Fix bugs introduced when merging.
jsjodin Aug 9, 2025
fa6e323
Move include
jsjodin Aug 11, 2025
bce250c
Remove unused code. Currently the canonicalize-omp can only handle a …
jsjodin Aug 19, 2025
8ae7ff6
Address review comments.
jsjodin Aug 22, 2025
a3d3db7
Undo unrelated change.
jsjodin Aug 22, 2025
8510c40
Remove stand-alone tiling.
jsjodin Aug 26, 2025
d462a6c
Revert unused changes.
jsjodin Aug 26, 2025
f59c027
Don't do codegen for tiling if it is an inner construct.
jsjodin Aug 26, 2025
b21af72
Remove unused stand-alone tiling code. Fix typo.
jsjodin Aug 27, 2025
7820ef5
Remove unused code.
jsjodin Aug 27, 2025
995fc53
Address review comments
jsjodin Aug 27, 2025
62b3ed1
Move include line
jsjodin Aug 28, 2025
71ed212
Move include to correct location hopefully.
jsjodin Aug 28, 2025
a0435d0
Make sure collapse is progagated correctly on the host side.
jsjodin Aug 28, 2025
4f2f7eb
Address more review comments.
jsjodin Aug 28, 2025
a687d97
Address review comments
jsjodin Sep 3, 2025
d1b69c9
Use shared helper function to find the sizes clause.
jsjodin Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion flang/include/flang/Lower/OpenMP.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ void genOpenMPDeclarativeConstruct(AbstractConverter &,
void genOpenMPSymbolProperties(AbstractConverter &converter,
const pft::Variable &var);

int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
void genThreadprivateOp(AbstractConverter &, const pft::Variable &);
void genDeclareTargetIntGlobal(AbstractConverter &, const pft::Variable &);
bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
Expand Down
18 changes: 15 additions & 3 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,15 @@ bool ClauseProcessor::processCancelDirectiveName(

bool ClauseProcessor::processCollapse(
mlir::Location currentLocation, lower::pft::Evaluation &eval,
mlir::omp::LoopRelatedClauseOps &result,
mlir::omp::LoopRelatedClauseOps &loopResult,
mlir::omp::CollapseClauseOps &collapseResult,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
return collectLoopRelatedInfo(converter, currentLocation, eval, clauses,
result, iv);

int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
clauses, loopResult, iv);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
return numCollapse > 1;
}

bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
Expand Down Expand Up @@ -522,6 +527,13 @@ bool ClauseProcessor::processProcBind(
return false;
}

bool ClauseProcessor::processTileSizes(
lower::pft::Evaluation &eval, mlir::omp::LoopNestOperands &result) const {
auto *ompCons{eval.getIf<parser::OpenMPConstruct>()};
collectTileSizesFromOpenMPConstruct(ompCons, result.tileSizes, semaCtx);
return !result.tileSizes.empty();
}

bool ClauseProcessor::processSafelen(
mlir::omp::SafelenClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Safelen>()) {
Expand Down
5 changes: 4 additions & 1 deletion flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class ClauseProcessor {
mlir::omp::CancelDirectiveNameClauseOps &result) const;
bool
processCollapse(mlir::Location currentLocation, lower::pft::Evaluation &eval,
mlir::omp::LoopRelatedClauseOps &result,
mlir::omp::LoopRelatedClauseOps &loopResult,
mlir::omp::CollapseClauseOps &collapseResult,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const;
bool processDevice(lower::StatementContext &stmtCtx,
mlir::omp::DeviceClauseOps &result) const;
Expand Down Expand Up @@ -98,6 +99,8 @@ class ClauseProcessor {
bool processPriority(lower::StatementContext &stmtCtx,
mlir::omp::PriorityClauseOps &result) const;
bool processProcBind(mlir::omp::ProcBindClauseOps &result) const;
bool processTileSizes(lower::pft::Evaluation &eval,
mlir::omp::LoopNestOperands &result) const;
bool processSafelen(mlir::omp::SafelenClauseOps &result) const;
bool processSchedule(lower::StatementContext &stmtCtx,
mlir::omp::ScheduleClauseOps &result) const;
Expand Down
31 changes: 10 additions & 21 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
[[fallthrough]];
case OMPD_distribute:
case OMPD_distribute_simd:
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
break;

case OMPD_teams:
Expand All @@ -522,7 +522,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
[[fallthrough]];
case OMPD_target_teams_distribute:
case OMPD_target_teams_distribute_simd:
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
cp.processNumTeams(stmtCtx, hostInfo->ops);
break;

Expand All @@ -533,7 +533,7 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_loop:
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->ops, hostInfo->iv);
break;

case OMPD_teams_workdistribute:
Expand Down Expand Up @@ -1569,9 +1569,10 @@ genLoopNestClauses(lower::AbstractConverter &converter,

HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
cp.processCollapse(loc, eval, clauseOps, iv);
cp.processCollapse(loc, eval, clauseOps, clauseOps, iv);

clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
cp.processTileSizes(eval, clauseOps);
}

static void genLoopClauses(
Expand Down Expand Up @@ -1948,9 +1949,9 @@ static mlir::omp::LoopNestOp genLoopNestOp(
return llvm::SmallVector<const semantics::Symbol *>(iv);
};

auto *nestedEval =
getCollapsedLoopEval(eval, getCollapseValue(item->clauses));

uint64_t nestValue = getCollapseValue(item->clauses);
nestValue = nestValue < iv.size() ? iv.size() : nestValue;
auto *nestedEval = getCollapsedLoopEval(eval, nestValue);
return genOpWithBody<mlir::omp::LoopNestOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, *nestedEval,
directive)
Expand Down Expand Up @@ -3843,8 +3844,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
switch (nestedDirective) {
case llvm::omp::Directive::OMPD_tile:
// Emit the omp.loop_nest with annotation for tiling
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value());
// Skip OMPD_tile since the tile sizes will be retrieved when
// generating the omp.loop_nest op.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Add a TODO for when the loop transformation is applied on its own, with no worksharing or similar construct associated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case it will not be a nested construct and will hit the TODO on line 3476.

break;
default: {
unsigned version = semaCtx.langOptions().OpenMPVersion;
Expand Down Expand Up @@ -3957,18 +3958,6 @@ void Fortran::lower::genOpenMPSymbolProperties(
lower::genDeclareTargetIntGlobal(converter, var);
}

int64_t
Fortran::lower::getCollapseValue(const parser::OmpClauseList &clauseList) {
for (const parser::OmpClause &clause : clauseList.v) {
if (const auto &collapseClause =
std::get_if<parser::OmpClause::Collapse>(&clause.u)) {
const auto *expr = semantics::GetExpr(collapseClause->v);
return evaluate::ToInt64(*expr).value();
}
}
return 1;
}

void Fortran::lower::genThreadprivateOp(lower::AbstractConverter &converter,
const lower::pft::Variable &var) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
Expand Down
92 changes: 88 additions & 4 deletions flang/lib/Lower/OpenMP/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "Utils.h"

#include "ClauseFinder.h"
#include "flang/Evaluate/fold.h"
#include "flang/Lower/OpenMP/Clauses.h"
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
Expand All @@ -24,11 +25,32 @@
#include <flang/Parser/parse-tree.h>
#include <flang/Parser/tools.h>
#include <flang/Semantics/tools.h>
#include <flang/Semantics/type.h>
#include <flang/Utils/OpenMP.h>
#include <llvm/Support/CommandLine.h>

#include <iterator>

template <typename T>
Fortran::semantics::MaybeIntExpr
EvaluateIntExpr(Fortran::semantics::SemanticsContext &context, const T &expr) {
if (Fortran::semantics::MaybeExpr maybeExpr{
Fold(context.foldingContext(), AnalyzeExpr(context, expr))}) {
if (auto *intExpr{
Fortran::evaluate::UnwrapExpr<Fortran::semantics::SomeIntExpr>(
*maybeExpr)}) {
return std::move(*intExpr);
}
}
return std::nullopt;
}

template <typename T>
std::optional<std::int64_t>
EvaluateInt64(Fortran::semantics::SemanticsContext &context, const T &expr) {
return Fortran::evaluate::ToInt64(EvaluateIntExpr(context, expr));
}

llvm::cl::opt<bool> treatIndexAsSection(
"openmp-treat-index-as-section",
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
Expand Down Expand Up @@ -577,12 +599,64 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
}
}

bool collectLoopRelatedInfo(
// Helper function that finds the sizes clause in a inner OMPD_tile directive
// and passes the sizes clause to the callback function if found.
static void processTileSizesFromOpenMPConstruct(
const parser::OpenMPConstruct *ompCons,
std::function<void(const parser::OmpClause::Sizes *)> processFun) {
if (!ompCons)
return;
if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
const auto &nestedOptional =
std::get<std::optional<parser::NestedConstruct>>(ompLoop->t);
assert(nestedOptional.has_value() &&
"Expected a DoConstruct or OpenMPLoopConstruct");
const auto *innerConstruct =
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&(nestedOptional.value()));
if (innerConstruct) {
const auto &innerLoopDirective = innerConstruct->value();
const auto &innerBegin =
std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t);
const auto &innerDirective =
std::get<parser::OmpLoopDirective>(innerBegin.t).v;

if (innerDirective == llvm::omp::Directive::OMPD_tile) {
// Get the size values from parse tree and convert to a vector.
const auto &innerClauseList{
std::get<parser::OmpClauseList>(innerBegin.t)};
for (const auto &clause : innerClauseList.v) {
if (const auto tclause{
std::get_if<parser::OmpClause::Sizes>(&clause.u)}) {
processFun(tclause);
break;
}
}
}
}
}
}

/// Populates the sizes vector with values if the given OpenMPConstruct
/// contains a loop construct with an inner tiling construct.
void collectTileSizesFromOpenMPConstruct(
const parser::OpenMPConstruct *ompCons,
llvm::SmallVectorImpl<int64_t> &tileSizes,
Fortran::semantics::SemanticsContext &semaCtx) {
processTileSizesFromOpenMPConstruct(
ompCons, [&](const parser::OmpClause::Sizes *tclause) {
for (auto &tval : tclause->v)
if (const auto v{EvaluateInt64(semaCtx, tval)})
tileSizes.push_back(*v);
});
}

int64_t collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
bool found = false;
int64_t numCollapse = 1;
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

// Collect the loops to collapse.
Expand All @@ -595,9 +669,19 @@ bool collectLoopRelatedInfo(
if (auto *clause =
ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) {
collapseValue = evaluate::ToInt64(clause->v).value();
found = true;
numCollapse = collapseValue;
}

// Collect sizes from tile directive if present.
std::int64_t sizesLengthValue = 0l;
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
processTileSizesFromOpenMPConstruct(
ompCons, [&](const parser::OmpClause::Sizes *tclause) {
sizesLengthValue = tclause->v.size();
});
}

collapseValue = std::max(collapseValue, sizesLengthValue);
std::size_t loopVarTypeSize = 0;
do {
lower::pft::Evaluation *doLoop =
Expand Down Expand Up @@ -631,7 +715,7 @@ bool collectLoopRelatedInfo(

convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);

return found;
return numCollapse;
}

} // namespace omp
Expand Down
7 changes: 6 additions & 1 deletion flang/lib/Lower/OpenMP/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,17 @@ void genObjectList(const ObjectList &objects,
void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
mlir::Location loc);

bool collectLoopRelatedInfo(
int64_t collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);

void collectTileSizesFromOpenMPConstruct(
const parser::OpenMPConstruct *ompCons,
llvm::SmallVectorImpl<int64_t> &tileSizes,
Fortran::semantics::SemanticsContext &semaCtx);

} // namespace omp
} // namespace lower
} // namespace Fortran
Expand Down
Loading