Skip to content

[OpenMP][OMPIRBuilder] Use device shared memory for arg structures #150925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: users/skatrak/flang-generic-03-mlir-shared-mem
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

namespace llvm {
class CanonicalLoopInfo;
class CodeExtractor;
class ScanInfo;
struct TargetRegionEntryInfo;
class OffloadEntriesInfoManager;
Expand Down Expand Up @@ -2248,17 +2249,27 @@ class OpenMPIRBuilder {
BasicBlock *EntryBB, *ExitBB, *OuterAllocaBB;
SmallVector<Value *, 2> ExcludeArgsFromAggregate;

LLVM_ABI virtual ~OutlineInfo() = default;

/// Collect all blocks in between EntryBB and ExitBB in both the given
/// vector and set.
LLVM_ABI void collectBlocks(SmallPtrSetImpl<BasicBlock *> &BlockSet,
SmallVectorImpl<BasicBlock *> &BlockVector);

/// Create a CodeExtractor instance based on the information stored in this
/// structure, the list of collected blocks from a previous call to
/// \c collectBlocks and a flag stating whether arguments must be passed in
/// address space 0.
///
/// Declared virtual so that subclasses of OutlineInfo can supply their own
/// CodeExtractor.
///
/// \param Blocks blocks to extract, as gathered by \c collectBlocks.
/// \param ArgsInZeroAddressSpace whether arguments must be passed in address
///        space 0.
/// \param Suffix defaults to an empty Twine. NOTE(review): presumably used as
///        the name suffix of the outlined function -- confirm against the
///        implementation.
/// \returns an owning pointer to the newly created CodeExtractor.
LLVM_ABI virtual std::unique_ptr<CodeExtractor>
createCodeExtractor(ArrayRef<BasicBlock *> Blocks,
bool ArgsInZeroAddressSpace, Twine Suffix = Twine(""));

/// Return the function that contains the region to be outlined.
///
/// This is the parent of \c EntryBB, which is dereferenced unconditionally,
/// so \c EntryBB must be non-null when this is called.
Function *getFunction() const { return EntryBB->getParent(); }
};

/// Collection of regions that need to be outlined during finalization.
SmallVector<OutlineInfo, 16> OutlineInfos;
SmallVector<std::unique_ptr<OutlineInfo>, 16> OutlineInfos;

/// A collection of candidate target functions whose constant allocas will
/// attempt to be raised on a call of finalize after all currently enqueued
Expand All @@ -2273,7 +2284,9 @@ class OpenMPIRBuilder {
std::forward_list<ScanInfo> ScanInfos;

/// Add a new region that will be outlined later.
void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); }
void addOutlineInfo(std::unique_ptr<OutlineInfo> &&OI) {
// Ownership of the region description is moved into OutlineInfos, leaving
// the caller's pointer null afterwards.
OutlineInfos.emplace_back(std::move(OI));
}

/// An ordered map of auto-generated variables to their unique names.
/// It stores variables with the following names: 1) ".gomp_critical_user_" +
Expand Down
51 changes: 39 additions & 12 deletions llvm/include/llvm/Transforms/Utils/CodeExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/Support/Compiler.h"
#include <limits>

namespace llvm {

template <typename PtrType> class SmallPtrSetImpl;
class AddrSpaceCastInst;
class AllocaInst;
class BasicBlock;
class BlockFrequency;
class BlockFrequencyInfo;
class BranchProbabilityInfo;
Expand Down Expand Up @@ -94,15 +95,23 @@ class CodeExtractorAnalysisCache {
BranchProbabilityInfo *BPI;
AssumptionCache *AC;

// A block outside of the extraction set where any intermediate
// allocations will be placed inside. If this is null, allocations
// will be placed in the entry block of the function.
/// A block outside of the extraction set in which any intermediate
/// allocations will be placed. If this is null, allocations
/// will be placed in the entry block of the function.
BasicBlock *AllocationBlock;

// If true, varargs functions can be extracted.
/// A block outside of the extraction set in which deallocations for
/// intermediate allocations can be placed. Not used for
/// automatically deallocated memory (e.g. `alloca`), which is the default.
///
/// If it is null and needed, the end of the replacement basic block will be
/// used to place deallocations.
BasicBlock *DeallocationBlock;

/// If true, varargs functions can be extracted.
bool AllowVarArgs;

// Bits of intermediate state computed at various phases of extraction.
/// Bits of intermediate state computed at various phases of extraction.
SetVector<BasicBlock *> Blocks;

/// Lists of blocks that are branched from the code region to be extracted,
Expand All @@ -124,13 +133,13 @@ class CodeExtractorAnalysisCache {
/// returns 1, etc.
SmallVector<BasicBlock *> ExtractedFuncRetVals;

// Suffix to use when creating extracted function (appended to the original
// function name + "."). If empty, the default is to use the entry block
// label, if non-empty, otherwise "extracted".
/// Suffix to use when creating extracted function (appended to the original
/// function name + "."). If empty, the default is to use the entry block
/// label, if non-empty, otherwise "extracted".
std::string Suffix;

// If true, the outlined function has aggregate argument in zero address
// space.
/// If true, the outlined function has aggregate argument in zero address
/// space.
bool ArgsInZeroAddressSpace;

public:
Expand All @@ -146,7 +155,9 @@ class CodeExtractorAnalysisCache {
/// however code extractor won't validate whether extraction is legal.
/// Any new allocations will be placed in the AllocationBlock, unless
/// it is null, in which case it will be placed in the entry block of
/// the function from which the code is being extracted.
/// the function from which the code is being extracted. Explicit
/// deallocations for the aforementioned allocations will be placed in the
/// DeallocationBlock or the end of the replacement block, if needed.
/// If ArgsInZeroAddressSpace param is set to true, then the aggregate
/// param pointer of the outlined function is declared in zero address
/// space.
Expand All @@ -157,8 +168,11 @@ class CodeExtractorAnalysisCache {
AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
bool AllowAlloca = false,
BasicBlock *AllocationBlock = nullptr,
BasicBlock *DeallocationBlock = nullptr,
std::string Suffix = "", bool ArgsInZeroAddressSpace = false);

LLVM_ABI virtual ~CodeExtractor() = default;

/// Perform the extraction, returning the new function.
///
/// Returns zero when called on a CodeExtractor instance where isEligible
Expand Down Expand Up @@ -243,6 +257,19 @@ class CodeExtractorAnalysisCache {
/// region, passing it instead as a scalar.
LLVM_ABI void excludeArgFromAggregate(Value *Arg);

protected:
/// Allocate an intermediate variable at the specified point.
///
/// Declared virtual so that subclasses can customize how intermediate
/// variables are allocated.
///
/// \param BB block in which the allocation is created.
/// \param AllocIP insertion point for the allocation.
/// \param VarType type of the variable to allocate.
/// \param Name optional name for the variable; defaults to an empty Twine.
/// \param CastedAlloc optional out-parameter, may be null (the default).
///        NOTE(review): presumably receives the address space cast created
///        for the allocation, if any -- confirm against the implementation.
/// \returns NOTE(review): presumably the instruction producing the allocated
///          variable -- confirm against the implementation.
LLVM_ABI virtual Instruction *
allocateVar(BasicBlock *BB, BasicBlock::iterator AllocIP, Type *VarType,
const Twine &Name = Twine(""),
AddrSpaceCastInst **CastedAlloc = nullptr);

/// Deallocate a previously-allocated intermediate variable at the specified
/// point.
///
/// Declared virtual so that subclasses can customize how intermediate
/// variables are released.
///
/// \param BB block in which the deallocation is created.
/// \param DeallocIP insertion point for the deallocation.
/// \param Var the variable to deallocate.
/// \param VarType type of the variable being deallocated.
/// \returns NOTE(review): presumably the deallocation instruction, possibly
///          null when no explicit deallocation is needed (e.g. `alloca`) --
///          confirm against the implementation.
LLVM_ABI virtual Instruction *deallocateVar(BasicBlock *BB,
BasicBlock::iterator DeallocIP,
Value *Var, Type *VarType);

private:
struct LifetimeMarkerInfo {
bool SinkLifeStart = false;
Expand Down
Loading