Skip to content

Commit 53c4338

Browse files
rofirrimgithub-actions[bot]
authored andcommitted
Automerge: [Clang][OpenMP] Add an additional class to hold data that will be shared between all loop transformations (#155849)
This is preparatory work for the implementation of `#pragma omp fuse` in llvm/llvm-project#139293 **Note**: this change builds on top of llvm/llvm-project#155848 This change adds an additional class to hold data that will be shared between all loop transformations: those that apply to canonical loop nests (the majority) and those that apply to canonical loop sequences (`fuse` in OpenMP 6.0). This class is not a statement by itself and its goal is to avoid having to replicate information between classes. Also simplfiy the way we handle the "generated loops" information as we currently only need to know if it is zero or non-zero.
2 parents e3c54ee + b9f84bc commit 53c4338

File tree

5 files changed

+52
-40
lines changed

5 files changed

+52
-40
lines changed

clang/include/clang/AST/StmtOpenMP.h

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -956,30 +956,46 @@ class OMPLoopBasedDirective : public OMPExecutableDirective {
956956
}
957957
};
958958

959+
/// Common class of data shared between
960+
/// OMPCanonicalLoopNestTransformationDirective and transformations over
961+
/// canonical loop sequences.
962+
class OMPLoopTransformationDirective {
963+
/// Number of (top-level) generated loops.
964+
/// This value is 1 for most transformations as they only map one loop nest
965+
/// into another.
966+
/// Some loop transformations (like a non-partial 'unroll') may not generate
967+
/// a loop nest, so this would be 0.
968+
/// Some loop transformations (like 'fuse' with looprange and 'split') may
969+
/// generate more than one loop nest, so the value would be >= 1.
970+
unsigned NumGeneratedTopLevelLoops = 1;
971+
972+
protected:
973+
void setNumGeneratedTopLevelLoops(unsigned N) {
974+
NumGeneratedTopLevelLoops = N;
975+
}
976+
977+
public:
978+
unsigned getNumGeneratedTopLevelLoops() const {
979+
return NumGeneratedTopLevelLoops;
980+
}
981+
};
982+
959983
/// The base class for all transformation directives of canonical loop nests.
960984
class OMPCanonicalLoopNestTransformationDirective
961-
: public OMPLoopBasedDirective {
985+
: public OMPLoopBasedDirective,
986+
public OMPLoopTransformationDirective {
962987
friend class ASTStmtReader;
963988

964-
/// Number of loops generated by this loop transformation.
965-
unsigned NumGeneratedLoops = 0;
966-
967989
protected:
968990
explicit OMPCanonicalLoopNestTransformationDirective(
969991
StmtClass SC, OpenMPDirectiveKind Kind, SourceLocation StartLoc,
970992
SourceLocation EndLoc, unsigned NumAssociatedLoops)
971993
: OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
972994

973-
/// Set the number of loops generated by this loop transformation.
974-
void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; }
975-
976995
public:
977996
/// Return the number of associated (consumed) loops.
978997
unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
979998

980-
/// Return the number of loops generated by this loop transformation.
981-
unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; }
982-
983999
/// Get the de-sugared statements after the loop transformation.
9841000
///
9851001
/// Might be nullptr if either the directive generates no loops and is handled
@@ -5560,9 +5576,7 @@ class OMPTileDirective final
55605576
unsigned NumLoops)
55615577
: OMPCanonicalLoopNestTransformationDirective(
55625578
OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc,
5563-
NumLoops) {
5564-
setNumGeneratedLoops(2 * NumLoops);
5565-
}
5579+
NumLoops) {}
55665580

55675581
void setPreInits(Stmt *PreInits) {
55685582
Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5638,9 +5652,7 @@ class OMPStripeDirective final
56385652
unsigned NumLoops)
56395653
: OMPCanonicalLoopNestTransformationDirective(
56405654
OMPStripeDirectiveClass, llvm::omp::OMPD_stripe, StartLoc, EndLoc,
5641-
NumLoops) {
5642-
setNumGeneratedLoops(2 * NumLoops);
5643-
}
5655+
NumLoops) {}
56445656

56455657
void setPreInits(Stmt *PreInits) {
56465658
Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5744,7 +5756,8 @@ class OMPUnrollDirective final
57445756
static OMPUnrollDirective *
57455757
Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
57465758
ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
5747-
unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits);
5759+
unsigned NumGeneratedTopLevelLoops, Stmt *TransformedStmt,
5760+
Stmt *PreInits);
57485761

57495762
/// Build an empty '#pragma omp unroll' AST node for deserialization.
57505763
///
@@ -5794,9 +5807,7 @@ class OMPReverseDirective final
57945807
unsigned NumLoops)
57955808
: OMPCanonicalLoopNestTransformationDirective(
57965809
OMPReverseDirectiveClass, llvm::omp::OMPD_reverse, StartLoc, EndLoc,
5797-
NumLoops) {
5798-
setNumGeneratedLoops(NumLoops);
5799-
}
5810+
NumLoops) {}
58005811

58015812
void setPreInits(Stmt *PreInits) {
58025813
Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5867,9 +5878,7 @@ class OMPInterchangeDirective final
58675878
SourceLocation EndLoc, unsigned NumLoops)
58685879
: OMPCanonicalLoopNestTransformationDirective(
58695880
OMPInterchangeDirectiveClass, llvm::omp::OMPD_interchange, StartLoc,
5870-
EndLoc, NumLoops) {
5871-
setNumGeneratedLoops(NumLoops);
5872-
}
5881+
EndLoc, NumLoops) {}
58735882

58745883
void setPreInits(Stmt *PreInits) {
58755884
Data->getChildren()[PreInitsOffset] = PreInits;

clang/lib/AST/StmtOpenMP.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,14 @@ bool OMPLoopBasedDirective::doForAllLoops(
139139

140140
Stmt *TransformedStmt = Dir->getTransformedStmt();
141141
if (!TransformedStmt) {
142-
unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops();
143-
if (NumGeneratedLoops == 0) {
142+
unsigned NumGeneratedTopLevelLoops =
143+
Dir->getNumGeneratedTopLevelLoops();
144+
if (NumGeneratedTopLevelLoops == 0) {
144145
// May happen if the loop transformation does not result in a
145146
// generated loop (such as full unrolling).
146147
break;
147148
}
148-
if (NumGeneratedLoops > 0) {
149+
if (NumGeneratedTopLevelLoops > 0) {
149150
// The loop transformation construct has generated loops, but these
150151
// may not have been generated yet due to being in a dependent
151152
// context.
@@ -447,16 +448,16 @@ OMPStripeDirective *OMPStripeDirective::CreateEmpty(const ASTContext &C,
447448
SourceLocation(), SourceLocation(), NumLoops);
448449
}
449450

450-
OMPUnrollDirective *
451-
OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc,
452-
SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
453-
Stmt *AssociatedStmt, unsigned NumGeneratedLoops,
454-
Stmt *TransformedStmt, Stmt *PreInits) {
455-
assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop");
451+
OMPUnrollDirective *OMPUnrollDirective::Create(
452+
const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
453+
ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
454+
unsigned NumGeneratedTopLevelLoops, Stmt *TransformedStmt, Stmt *PreInits) {
455+
assert(NumGeneratedTopLevelLoops <= 1 &&
456+
"Unrolling generates at most one loop");
456457

457458
auto *Dir = createDirective<OMPUnrollDirective>(
458459
C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
459-
Dir->setNumGeneratedLoops(NumGeneratedLoops);
460+
Dir->setNumGeneratedTopLevelLoops(NumGeneratedTopLevelLoops);
460461
Dir->setTransformedStmt(TransformedStmt);
461462
Dir->setPreInits(PreInits);
462463
return Dir;

clang/lib/Sema/SemaOpenMP.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14983,12 +14983,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1498314983
Body, OriginalInits))
1498414984
return StmtError();
1498514985

14986-
unsigned NumGeneratedLoops = PartialClause ? 1 : 0;
14986+
unsigned NumGeneratedTopLevelLoops = PartialClause ? 1 : 0;
1498714987

1498814988
// Delay unrolling to when template is completely instantiated.
1498914989
if (SemaRef.CurContext->isDependentContext())
1499014990
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
14991-
NumGeneratedLoops, nullptr, nullptr);
14991+
NumGeneratedTopLevelLoops, nullptr,
14992+
nullptr);
1499214993

1499314994
assert(LoopHelpers.size() == NumLoops &&
1499414995
"Expecting a single-dimensional loop iteration space");
@@ -15011,9 +15012,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1501115012
// The generated loop may only be passed to other loop-associated directive
1501215013
// when a partial clause is specified. Without the requirement it is
1501315014
// sufficient to generate loop unroll metadata at code-generation.
15014-
if (NumGeneratedLoops == 0)
15015+
if (NumGeneratedTopLevelLoops == 0)
1501515016
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
15016-
NumGeneratedLoops, nullptr, nullptr);
15017+
NumGeneratedTopLevelLoops, nullptr,
15018+
nullptr);
1501715019

1501815020
// Otherwise, we need to provide a de-sugared/transformed AST that can be
1501915021
// associated with another loop directive.
@@ -15228,7 +15230,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
1522815230
LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
1522915231

1523015232
return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
15231-
NumGeneratedLoops, OuterFor,
15233+
NumGeneratedTopLevelLoops, OuterFor,
1523215234
buildPreInits(Context, PreInits));
1523315235
}
1523415236

clang/lib/Serialization/ASTReaderStmt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2450,7 +2450,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) {
24502450
void ASTStmtReader::VisitOMPCanonicalLoopNestTransformationDirective(
24512451
OMPCanonicalLoopNestTransformationDirective *D) {
24522452
VisitOMPLoopBasedDirective(D);
2453-
D->setNumGeneratedLoops(Record.readUInt32());
2453+
D->setNumGeneratedTopLevelLoops(Record.readUInt32());
24542454
}
24552455

24562456
void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {

clang/lib/Serialization/ASTWriterStmt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2459,7 +2459,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) {
24592459
void ASTStmtWriter::VisitOMPCanonicalLoopNestTransformationDirective(
24602460
OMPCanonicalLoopNestTransformationDirective *D) {
24612461
VisitOMPLoopBasedDirective(D);
2462-
Record.writeUInt32(D->getNumGeneratedLoops());
2462+
Record.writeUInt32(D->getNumGeneratedTopLevelLoops());
24632463
}
24642464

24652465
void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {

0 commit comments

Comments
 (0)