13
13
#include " Utils.h"
14
14
15
15
#include " ClauseFinder.h"
16
+ #include " flang/Evaluate/fold.h"
16
17
#include " flang/Lower/OpenMP/Clauses.h"
17
18
#include < flang/Lower/AbstractConverter.h>
18
19
#include < flang/Lower/ConvertType.h>
24
25
#include < flang/Parser/parse-tree.h>
25
26
#include < flang/Parser/tools.h>
26
27
#include < flang/Semantics/tools.h>
28
+ #include < flang/Semantics/type.h>
27
29
#include < flang/Utils/OpenMP.h>
28
30
#include < llvm/Support/CommandLine.h>
29
31
30
32
#include < iterator>
31
33
34
+ template <typename T>
35
+ Fortran::semantics::MaybeIntExpr
36
+ EvaluateIntExpr (Fortran::semantics::SemanticsContext &context, const T &expr) {
37
+ if (Fortran::semantics::MaybeExpr maybeExpr{
38
+ Fold (context.foldingContext (), AnalyzeExpr (context, expr))}) {
39
+ if (auto *intExpr{
40
+ Fortran::evaluate::UnwrapExpr<Fortran::semantics::SomeIntExpr>(
41
+ *maybeExpr)}) {
42
+ return std::move (*intExpr);
43
+ }
44
+ }
45
+ return std::nullopt ;
46
+ }
47
+
48
+ template <typename T>
49
+ std::optional<std::int64_t >
50
+ EvaluateInt64 (Fortran::semantics::SemanticsContext &context, const T &expr) {
51
+ return Fortran::evaluate::ToInt64 (EvaluateIntExpr (context, expr));
52
+ }
53
+
32
54
llvm::cl::opt<bool > treatIndexAsSection (
33
55
" openmp-treat-index-as-section" ,
34
56
llvm::cl::desc (" In the OpenMP data clauses treat `a(N)` as `a(N:N)`." ),
@@ -577,12 +599,64 @@ static void convertLoopBounds(lower::AbstractConverter &converter,
577
599
}
578
600
}
579
601
580
- bool collectLoopRelatedInfo (
602
+ // Helper function that finds the sizes clause in a inner OMPD_tile directive
603
+ // and passes the sizes clause to the callback function if found.
604
+ static void processTileSizesFromOpenMPConstruct (
605
+ const parser::OpenMPConstruct *ompCons,
606
+ std::function<void (const parser::OmpClause::Sizes *)> processFun) {
607
+ if (!ompCons)
608
+ return ;
609
+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u )}) {
610
+ const auto &nestedOptional =
611
+ std::get<std::optional<parser::NestedConstruct>>(ompLoop->t );
612
+ assert (nestedOptional.has_value () &&
613
+ " Expected a DoConstruct or OpenMPLoopConstruct" );
614
+ const auto *innerConstruct =
615
+ std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
616
+ &(nestedOptional.value ()));
617
+ if (innerConstruct) {
618
+ const auto &innerLoopDirective = innerConstruct->value ();
619
+ const auto &innerBegin =
620
+ std::get<parser::OmpBeginLoopDirective>(innerLoopDirective.t );
621
+ const auto &innerDirective =
622
+ std::get<parser::OmpLoopDirective>(innerBegin.t ).v ;
623
+
624
+ if (innerDirective == llvm::omp::Directive::OMPD_tile) {
625
+ // Get the size values from parse tree and convert to a vector.
626
+ const auto &innerClauseList{
627
+ std::get<parser::OmpClauseList>(innerBegin.t )};
628
+ for (const auto &clause : innerClauseList.v ) {
629
+ if (const auto tclause{
630
+ std::get_if<parser::OmpClause::Sizes>(&clause.u )}) {
631
+ processFun (tclause);
632
+ break ;
633
+ }
634
+ }
635
+ }
636
+ }
637
+ }
638
+ }
639
+
640
+ // / Populates the sizes vector with values if the given OpenMPConstruct
641
+ // / contains a loop construct with an inner tiling construct.
642
+ void collectTileSizesFromOpenMPConstruct (
643
+ const parser::OpenMPConstruct *ompCons,
644
+ llvm::SmallVectorImpl<int64_t > &tileSizes,
645
+ Fortran::semantics::SemanticsContext &semaCtx) {
646
+ processTileSizesFromOpenMPConstruct (
647
+ ompCons, [&](const parser::OmpClause::Sizes *tclause) {
648
+ for (auto &tval : tclause->v )
649
+ if (const auto v{EvaluateInt64 (semaCtx, tval)})
650
+ tileSizes.push_back (*v);
651
+ });
652
+ }
653
+
654
+ int64_t collectLoopRelatedInfo (
581
655
lower::AbstractConverter &converter, mlir::Location currentLocation,
582
656
lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
583
657
mlir::omp::LoopRelatedClauseOps &result,
584
658
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
585
- bool found = false ;
659
+ int64_t numCollapse = 1 ;
586
660
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
587
661
588
662
// Collect the loops to collapse.
@@ -595,9 +669,19 @@ bool collectLoopRelatedInfo(
595
669
if (auto *clause =
596
670
ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) {
597
671
collapseValue = evaluate::ToInt64 (clause->v ).value ();
598
- found = true ;
672
+ numCollapse = collapseValue;
673
+ }
674
+
675
+ // Collect sizes from tile directive if present.
676
+ std::int64_t sizesLengthValue = 0l ;
677
+ if (auto *ompCons{eval.getIf <parser::OpenMPConstruct>()}) {
678
+ processTileSizesFromOpenMPConstruct (
679
+ ompCons, [&](const parser::OmpClause::Sizes *tclause) {
680
+ sizesLengthValue = tclause->v .size ();
681
+ });
599
682
}
600
683
684
+ collapseValue = std::max (collapseValue, sizesLengthValue);
601
685
std::size_t loopVarTypeSize = 0 ;
602
686
do {
603
687
lower::pft::Evaluation *doLoop =
@@ -631,7 +715,7 @@ bool collectLoopRelatedInfo(
631
715
632
716
convertLoopBounds (converter, currentLocation, result, loopVarTypeSize);
633
717
634
- return found ;
718
+ return numCollapse ;
635
719
}
636
720
637
721
} // namespace omp
0 commit comments