From 1e925aa7af3ba333c89c04e6a10b3c307f701dc1 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sun, 25 Oct 2020 14:23:19 -0700 Subject: [PATCH] Add a new API: llvm::TargetLibraryInfoImpl::setTapirTarget(factory) --- llvm/include/llvm/Analysis/TargetLibraryInfo.h | 6 ++++++ llvm/include/llvm/Transforms/Tapir/LoweringUtils.h | 3 +++ llvm/include/llvm/Transforms/Tapir/TapirTargetIDs.h | 1 + llvm/lib/Transforms/Tapir/LoweringUtils.cpp | 9 +++++++++ 4 files changed, 19 insertions(+) diff --git a/llvm/include/llvm/Analysis/TargetLibraryInfo.h b/llvm/include/llvm/Analysis/TargetLibraryInfo.h index fb6934fc65bf..922807e72ef0 100644 --- a/llvm/include/llvm/Analysis/TargetLibraryInfo.h +++ b/llvm/include/llvm/Analysis/TargetLibraryInfo.h @@ -9,6 +9,7 @@ #ifndef LLVM_ANALYSIS_TARGETLIBRARYINFO_H #define LLVM_ANALYSIS_TARGETLIBRARYINFO_H +#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/Triple.h" @@ -18,6 +19,7 @@ #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" #include "llvm/Transforms/Tapir/TapirTargetIDs.h" +#include "llvm/Transforms/Tapir/LoweringUtils.h" namespace llvm { template class ArrayRef; @@ -38,6 +40,8 @@ struct VecDesc { NumLibFuncs }; +using TapirTargetFactory = std::function; + /// Implementation of the target library information. /// /// This class constructs tables that hold the target library information and @@ -204,6 +208,8 @@ class TargetLibraryInfoImpl { TapirTarget = TargetID; } + void setTapirTarget(TapirTargetFactory target); + /// Return the ID of the target for Tapir lowering. TapirTargetID getTapirTarget() const { return TapirTarget; diff --git a/llvm/include/llvm/Transforms/Tapir/LoweringUtils.h b/llvm/include/llvm/Transforms/Tapir/LoweringUtils.h index fc2bb13832ab..56953bad9ed5 100644 --- a/llvm/include/llvm/Transforms/Tapir/LoweringUtils.h +++ b/llvm/include/llvm/Transforms/Tapir/LoweringUtils.h @@ -13,6 +13,7 @@ #ifndef LOWERING_UTILS_H_ #define LOWERING_UTILS_H_ +#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" @@ -409,6 +410,8 @@ class LoopOutlineProcessor { /// Generate a TapirTarget object for the specified TapirTargetID. TapirTarget *getTapirTargetFromID(Module &M, TapirTargetID TargetID); +void setCustomTapirTarget(std::function); + /// Find all inputs to tasks within a function \p F, including nested tasks. TaskValueSetMap findAllTaskInputs(Function &F, const DominatorTree &DT, const TaskInfo &TI); diff --git a/llvm/include/llvm/Transforms/Tapir/TapirTargetIDs.h b/llvm/include/llvm/Transforms/Tapir/TapirTargetIDs.h index 49fbcd3423b3..6c4b2c4fb680 100644 --- a/llvm/include/llvm/Transforms/Tapir/TapirTargetIDs.h +++ b/llvm/include/llvm/Transforms/Tapir/TapirTargetIDs.h @@ -25,6 +25,7 @@ enum class TapirTargetID { OpenCilk, // Lower to OpenCilk ABI OpenMP, // Lower to OpenMP Qthreads, // Lower to Qthreads + Custom, Last_TapirTargetID }; diff --git a/llvm/lib/Transforms/Tapir/LoweringUtils.cpp b/llvm/lib/Transforms/Tapir/LoweringUtils.cpp index 2dc6ca72d66e..54fe4668c762 100644 --- a/llvm/lib/Transforms/Tapir/LoweringUtils.cpp +++ b/llvm/lib/Transforms/Tapir/LoweringUtils.cpp @@ -36,6 +36,13 @@ using namespace llvm; static const char TimerGroupName[] = DEBUG_TYPE; static const char TimerGroupDescription[] = "Tapir lowering"; +static TapirTargetFactory CUSTOM_TAPIR_TARGET; + +void llvm::TargetLibraryInfoImpl::setTapirTarget(TapirTargetFactory target) { + TapirTarget = TapirTargetID::Custom; + CUSTOM_TAPIR_TARGET = target; +} + TapirTarget *llvm::getTapirTargetFromID(Module &M, TapirTargetID ID) { switch (ID) { case TapirTargetID::None: @@ -53,6 +60,8 @@ TapirTarget *llvm::getTapirTargetFromID(Module &M, TapirTargetID ID) { return new OpenMPABI(M); case TapirTargetID::Qthreads: return new QthreadsABI(M); + case TapirTargetID::Custom: + return CUSTOM_TAPIR_TARGET(M); default: llvm_unreachable("Invalid TapirTargetID"); }