Skip to content

Commit eccbf9e

Browse files
committed
[Flang][OpenMP] Make implicitly captured scalars fully firstprivatized
Currently, we indicate to the runtime that implicit scalar captures are firstprivate (via map and capture types), which is enough for the runtime trace to treat them as such, but we do not CodeGen the IR in such a way that we can take full advantage of this aspect of the OpenMP specification. This patch seeks to change that by applying the correct symbol flags (firstprivate/implicit) to the implicitly captured scalars within target regions, which then triggers the delayed privatization code generation for these symbols, bringing the code generation in line with the explicit firstprivate clause. Currently, similarly to delayed privatization, I have sheltered this segment of code behind the EnabledDelayedPrivitization flag, as without it we would trigger a compiler error for firstprivate not being supported any time we implicitly capture a scalar and try to firstprivatize it. In future, when this flag is removed, it can also be removed here. So, for now, you need to enable this by passing the flag to the compiler when compiling any programs.
1 parent fd8f69d commit eccbf9e

15 files changed

+429
-84
lines changed

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,42 @@ bool DataSharingProcessor::OMPConstructSymbolVisitor::isSymbolDefineBy(
4343
[](const auto &functionParserNode) { return false; }});
4444
}
4545

46+
/// Returns true when \p eval holds an OpenMP construct whose directive is
/// `target` itself or a combined/composite directive whose outermost leaf is
/// `target`. Returns false for every other directive and for evaluations that
/// are not OpenMP constructs.
///
/// `target parallel` and `target teams` are included alongside the longer
/// combined forms: they also begin with a target region, so a caller asking
/// "is the top-level leaf a target?" must get the same answer for them.
static bool isConstructWithTopLevelTarget(lower::pft::Evaluation &eval) {
  const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
  if (!ompEval)
    return false;

  switch (parser::omp::GetOmpDirectiveName(ompEval).v) {
  case llvm::omp::Directive::OMPD_target:
  case llvm::omp::Directive::OMPD_target_loop:
  case llvm::omp::Directive::OMPD_target_parallel:
  case llvm::omp::Directive::OMPD_target_parallel_do:
  case llvm::omp::Directive::OMPD_target_parallel_do_simd:
  case llvm::omp::Directive::OMPD_target_parallel_loop:
  case llvm::omp::Directive::OMPD_target_simd:
  case llvm::omp::Directive::OMPD_target_teams:
  case llvm::omp::Directive::OMPD_target_teams_distribute:
  case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do:
  case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do_simd:
  case llvm::omp::Directive::OMPD_target_teams_distribute_simd:
  case llvm::omp::Directive::OMPD_target_teams_loop:
    return true;
  default:
    return false;
  }
}
71+
4672
DataSharingProcessor::DataSharingProcessor(
4773
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
4874
const List<Clause> &clauses, lower::pft::Evaluation &eval,
4975
bool shouldCollectPreDeterminedSymbols, bool useDelayedPrivatization,
50-
lower::SymMap &symTable)
76+
lower::SymMap &symTable, bool isTargetPrivitization)
5177
: converter(converter), semaCtx(semaCtx),
5278
firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
5379
shouldCollectPreDeterminedSymbols(shouldCollectPreDeterminedSymbols),
5480
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable),
55-
visitor(semaCtx) {
81+
isTargetPrivitization(isTargetPrivitization), visitor(semaCtx) {
5682
eval.visit([&](const auto &functionParserNode) {
5783
parser::Walk(functionParserNode, visitor);
5884
});
@@ -62,10 +88,12 @@ DataSharingProcessor::DataSharingProcessor(lower::AbstractConverter &converter,
6288
semantics::SemanticsContext &semaCtx,
6389
lower::pft::Evaluation &eval,
6490
bool useDelayedPrivatization,
65-
lower::SymMap &symTable)
91+
lower::SymMap &symTable,
92+
bool isTargetPrivitization)
6693
: DataSharingProcessor(converter, semaCtx, {}, eval,
6794
/*shouldCollectPreDeterminedSymols=*/false,
68-
useDelayedPrivatization, symTable) {}
95+
useDelayedPrivatization, symTable,
96+
isTargetPrivitization) {}
6997

7098
void DataSharingProcessor::processStep1(
7199
mlir::omp::PrivateClauseOps *clauseOps) {
@@ -552,8 +580,19 @@ void DataSharingProcessor::collectSymbols(
552580
};
553581

554582
auto shouldCollectSymbol = [&](const semantics::Symbol *sym) {
555-
if (collectImplicit)
583+
if (collectImplicit) {
584+
// If we're a combined construct with a target region, implicit
585+
// firstprivate captures, should only belong to the target region
586+
// and not be added/captured by later directives. Parallel regions
587+
// will likely want the same captures to be shared and for SIMD it's
588+
// illegal to have firstprivate clauses.
589+
if (isConstructWithTopLevelTarget(eval) && !isTargetPrivitization &&
590+
sym->test(semantics::Symbol::Flag::OmpFirstPrivate)) {
591+
return false;
592+
}
593+
556594
return sym->test(semantics::Symbol::Flag::OmpImplicit);
595+
}
557596

558597
if (collectPreDetermined)
559598
return sym->test(semantics::Symbol::Flag::OmpPreDetermined);

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class DataSharingProcessor {
9393
bool useDelayedPrivatization;
9494
llvm::SmallSet<const semantics::Symbol *, 16> mightHaveReadHostSym;
9595
lower::SymMap &symTable;
96+
bool isTargetPrivitization;
9697
OMPConstructSymbolVisitor visitor;
9798

9899
bool needBarrier();
@@ -130,12 +131,14 @@ class DataSharingProcessor {
130131
const List<Clause> &clauses,
131132
lower::pft::Evaluation &eval,
132133
bool shouldCollectPreDeterminedSymbols,
133-
bool useDelayedPrivatization, lower::SymMap &symTable);
134+
bool useDelayedPrivatization, lower::SymMap &symTable,
135+
bool isTargetPrivitization = false);
134136

135137
DataSharingProcessor(lower::AbstractConverter &converter,
136138
semantics::SemanticsContext &semaCtx,
137139
lower::pft::Evaluation &eval,
138-
bool useDelayedPrivatization, lower::SymMap &symTable);
140+
bool useDelayedPrivatization, lower::SymMap &symTable,
141+
bool isTargetPrivitization = false);
139142

140143
// Privatisation is split into two steps.
141144
// Step1 performs cloning of all privatisation clauses and copying for

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,6 +2470,36 @@ genSingleOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24702470
queue, item, clauseOps);
24712471
}
24722472

2473+
static bool isDuplicateMappedSymbol(
2474+
const semantics::Symbol &sym,
2475+
const llvm::SetVector<const semantics::Symbol *> &privatizedSyms,
2476+
const llvm::SmallVectorImpl<const semantics::Symbol *> &hasDevSyms,
2477+
const llvm::SmallVectorImpl<const semantics::Symbol *> &mappedSyms) {
2478+
llvm::SmallVector<const semantics::Symbol *> concatSyms;
2479+
concatSyms.reserve(privatizedSyms.size() + hasDevSyms.size() +
2480+
mappedSyms.size());
2481+
concatSyms.append(privatizedSyms.begin(), privatizedSyms.end());
2482+
concatSyms.append(hasDevSyms.begin(), hasDevSyms.end());
2483+
concatSyms.append(mappedSyms.begin(), mappedSyms.end());
2484+
2485+
auto checkSymbol = [&](const semantics::Symbol &checkSym) {
2486+
if (llvm::is_contained(concatSyms, &checkSym))
2487+
return true;
2488+
2489+
return std::any_of(concatSyms.begin(), concatSyms.end(),
2490+
[&](auto v) { return v->GetUltimate() == checkSym; });
2491+
};
2492+
2493+
if (checkSymbol(sym))
2494+
return true;
2495+
2496+
const auto *hostAssoc{sym.detailsIf<semantics::HostAssocDetails>()};
2497+
if (hostAssoc && checkSymbol(hostAssoc->symbol()))
2498+
return true;
2499+
2500+
return checkSymbol(sym.GetUltimate());
2501+
}
2502+
24732503
static mlir::omp::TargetOp
24742504
genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24752505
lower::StatementContext &stmtCtx,
@@ -2496,7 +2526,8 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24962526
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
24972527
/*shouldCollectPreDeterminedSymbols=*/
24982528
lower::omp::isLastItemInQueue(item, queue),
2499-
/*useDelayedPrivatization=*/true, symTable);
2529+
/*useDelayedPrivatization=*/true, symTable,
2530+
/*isTargetPrivitization=*/true);
25002531
dsp.processStep1(&clauseOps);
25012532

25022533
// 5.8.1 Implicit Data-Mapping Attribute Rules
@@ -2505,17 +2536,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25052536
// attribute clauses (neither data-sharing; e.g. `private`, nor `map`
25062537
// clauses).
25072538
auto captureImplicitMap = [&](const semantics::Symbol &sym) {
2508-
if (dsp.getAllSymbolsToPrivatize().contains(&sym))
2509-
return;
2510-
2511-
// Skip parameters/constants as they do not need to be mapped.
2512-
if (semantics::IsNamedConstant(sym))
2513-
return;
2514-
2515-
// These symbols are mapped individually in processHasDeviceAddr.
2516-
if (llvm::is_contained(hasDeviceAddrSyms, &sym))
2517-
return;
2518-
25192539
// Structure component symbols don't have bindings, and can only be
25202540
// explicitly mapped individually. If a member is captured implicitly
25212541
// we map the entirety of the derived type when we find its symbol.
@@ -2537,7 +2557,12 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25372557
if (!converter.getSymbolAddress(sym))
25382558
return;
25392559

2540-
if (!llvm::is_contained(mapSyms, &sym)) {
2560+
// Skip parameters/constants as they do not need to be mapped.
2561+
if (semantics::IsNamedConstant(sym))
2562+
return;
2563+
2564+
if (!isDuplicateMappedSymbol(sym, dsp.getAllSymbolsToPrivatize(),
2565+
hasDeviceAddrSyms, mapSyms)) {
25412566
if (const auto *details =
25422567
sym.template detailsIf<semantics::HostAssocDetails>())
25432568
converter.copySymbolBinding(details->symbol(), sym);

flang/lib/Lower/Support/Utils.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "flang/Lower/Support/Utils.h"
1414

15+
#include "flang/Common/idioms.h"
1516
#include "flang/Common/indirection.h"
1617
#include "flang/Lower/AbstractConverter.h"
1718
#include "flang/Lower/ConvertVariable.h"

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ class MapsForPrivatizedSymbolsPass
5555

5656
omp::MapInfoOp createMapInfo(Location loc, Value var,
5757
fir::FirOpBuilder &builder) {
58+
// Check if a value of type `type` can be passed to the kernel by value.
59+
// All kernel parameters are of pointer type, so if the value can be
60+
// represented inside of a pointer, then it can be passed by value.
61+
auto isLiteralType = [&](mlir::Type type) {
62+
const mlir::DataLayout &dl = builder.getDataLayout();
63+
mlir::Type ptrTy = mlir::LLVM::LLVMPointerType::get(builder.getContext());
64+
uint64_t ptrSize = dl.getTypeSize(ptrTy);
65+
uint64_t ptrAlign = dl.getTypePreferredAlignment(ptrTy);
66+
67+
auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
68+
loc, type, dl, builder.getKindMap());
69+
return size <= ptrSize && align <= ptrAlign;
70+
};
71+
5872
uint64_t mapTypeTo = static_cast<
5973
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
6074
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
@@ -94,14 +108,22 @@ class MapsForPrivatizedSymbolsPass
94108
if (needsBoundsOps(varPtr))
95109
genBoundsOps(builder, varPtr, boundsOps);
96110

111+
mlir::omp::VariableCaptureKind captureKind =
112+
mlir::omp::VariableCaptureKind::ByRef;
113+
if (fir::isa_trivial(fir::unwrapRefType(varPtr.getType())) ||
114+
fir::isa_char(fir::unwrapRefType(varPtr.getType()))) {
115+
if (isLiteralType(fir::unwrapRefType(varPtr.getType()))) {
116+
captureKind = mlir::omp::VariableCaptureKind::ByCopy;
117+
}
118+
}
119+
97120
return builder.create<omp::MapInfoOp>(
98121
loc, varPtr.getType(), varPtr,
99122
TypeAttr::get(llvm::cast<omp::PointerLikeType>(varPtr.getType())
100123
.getElementType()),
101124
builder.getIntegerAttr(builder.getIntegerType(64, /*isSigned=*/false),
102125
mapTypeTo),
103-
builder.getAttr<omp::VariableCaptureKindAttr>(
104-
omp::VariableCaptureKind::ByRef),
126+
builder.getAttr<omp::VariableCaptureKindAttr>(captureKind),
105127
/*varPtrPtr=*/Value{},
106128
/*members=*/SmallVector<Value>{},
107129
/*member_index=*/mlir::ArrayAttr{},

0 commit comments

Comments
 (0)