Skip to content

Commit aa2b629

Browse files
committed
[SYCL][Fusion] Rebase and address feedback
Signed-off-by: Lukas Sommer <[email protected]>
1 parent 26978b2 commit aa2b629

File tree

20 files changed

+412
-350
lines changed

20 files changed

+412
-350
lines changed

clang/include/clang/Driver/Action.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ class OffloadWrapperJobAction : public JobAction {
665665
public:
666666
OffloadWrapperJobAction(ActionList &Inputs, types::ID Type);
667667
OffloadWrapperJobAction(Action *Input, types::ID OutputType,
668-
bool IsEmbeddedIR = false);
668+
bool EmbedIR = false);
669669

670670
bool isEmbeddedIR() const { return EmbedIR; }
671671

clang/lib/Driver/Action.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ void OffloadWrapperJobAction::anchor() {}
478478

479479
OffloadWrapperJobAction::OffloadWrapperJobAction(ActionList &Inputs,
480480
types::ID Type)
481-
: JobAction(OffloadWrapperJobClass, Inputs, Type) {}
481+
: JobAction(OffloadWrapperJobClass, Inputs, Type), EmbedIR(false) {}
482482

483483
OffloadWrapperJobAction::OffloadWrapperJobAction(Action *Input, types::ID Type,
484484
bool IsEmbeddedIR)

sycl-fusion/jit-compiler/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ add_llvm_library(sycl-fusion
2222
InstCombine
2323
Target
2424
TargetParser
25-
NVPTX
26-
X86
2725
MC
26+
${LLVM_TARGETS_TO_BUILD}
2827
)
2928

3029
target_include_directories(sycl-fusion
@@ -47,6 +46,10 @@ target_link_libraries(sycl-fusion
4746
${CMAKE_THREAD_LIBS_INIT}
4847
)
4948

49+
if("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD)
50+
target_compile_definitions(sycl-fusion PRIVATE FUSION_JIT_SUPPORT_PTX)
51+
endif()
52+
5053
if (BUILD_SHARED_LIBS)
5154
if(NOT MSVC AND NOT APPLE)
5255
# Manage symbol visibility through the linker to make sure no LLVM symbols

sycl-fusion/jit-compiler/include/JITContext.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ using CacheKeyT =
3939
/// Wrapper around a kernel binary.
4040
class KernelBinary {
4141
public:
42-
explicit KernelBinary(std::string Binary, BinaryFormat Format);
42+
explicit KernelBinary(std::string &&Binary, BinaryFormat Format);
4343

4444
jit_compiler::BinaryAddress address() const;
4545

@@ -65,7 +65,10 @@ class JITContext {
6565

6666
llvm::LLVMContext *getLLVMContext();
6767

68-
KernelBinary &emplaceSPIRVBinary(std::string Binary, BinaryFormat Format);
68+
template <typename... Ts> KernelBinary &emplaceKernelBinary(Ts &&...Args) {
69+
WriteLockT WriteLock{BinariesMutex};
70+
return Binaries.emplace_back(std::forward<Ts>(Args)...);
71+
}
6972

7073
std::optional<SYCLKernelInfo> getCacheEntry(CacheKeyT &Identifier) const;
7174

sycl-fusion/jit-compiler/include/Options.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define SYCL_FUSION_JIT_COMPILER_OPTIONS_H
1111

1212
#include "Kernel.h"
13+
1314
#include <memory>
1415
#include <unordered_map>
1516

sycl-fusion/jit-compiler/lib/JITContext.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
using namespace jit_compiler;
1313

14-
KernelBinary::KernelBinary(std::string Binary, BinaryFormat Fmt)
14+
KernelBinary::KernelBinary(std::string &&Binary, BinaryFormat Fmt)
1515
: Blob{std::move(Binary)}, Format{Fmt} {}
1616

1717
jit_compiler::BinaryAddress KernelBinary::address() const {
@@ -29,15 +29,6 @@ JITContext::~JITContext() = default;
2929

3030
llvm::LLVMContext *JITContext::getLLVMContext() { return LLVMCtx.get(); }
3131

32-
KernelBinary &JITContext::emplaceSPIRVBinary(std::string Binary,
33-
BinaryFormat Format) {
34-
WriteLockT WriteLock{BinariesMutex};
35-
// NOTE: With C++17, which returns a reference from emplace_back, the
36-
// following code would be even simpler.
37-
Binaries.emplace_back(std::move(Binary), Format);
38-
return Binaries.back();
39-
}
40-
4132
std::optional<SYCLKernelInfo>
4233
JITContext::getCacheEntry(CacheKeyT &Identifier) const {
4334
ReadLockT ReadLock{CacheMutex};

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ gatherNDRanges(llvm::ArrayRef<SYCLKernelInfo> KernelInformation) {
4848
return NDRanges;
4949
}
5050

51+
static bool isTargetFormatSupported(BinaryFormat TargetFormat) {
52+
switch (TargetFormat) {
53+
case BinaryFormat::SPIRV:
54+
return true;
55+
case BinaryFormat::PTX: {
56+
#ifdef FUSION_JIT_SUPPORT_PTX
57+
return true;
58+
#else // FUSION_JIT_SUPPORT_PTX
59+
return false;
60+
#endif // FUSION_JIT_SUPPORT_PTX
61+
}
62+
default:
63+
return false;
64+
}
65+
}
66+
5167
FusionResult KernelFusion::fuseKernels(
5268
JITContext &JITCtx, Config &&JITConfig,
5369
const std::vector<SYCLKernelInfo> &KernelInformation,
@@ -71,6 +87,12 @@ FusionResult KernelFusion::fuseKernels(
7187
bool IsHeterogeneousList = jit_compiler::isHeterogeneousList(NDRanges);
7288

7389
BinaryFormat TargetFormat = ConfigHelper::get<option::JITTargetFormat>();
90+
91+
if (!isTargetFormatSupported(TargetFormat)) {
92+
return FusionResult(
93+
"Fusion output target format not supported by this build");
94+
}
95+
7496
if (TargetFormat == BinaryFormat::PTX && IsHeterogeneousList) {
7597
return FusionResult{"Heterogeneous ND ranges not supported for CUDA"};
7698
}

sycl-fusion/jit-compiler/lib/fusion/FusionPipeline.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
#ifndef NDEBUG
2424
#include "llvm/IR/Verifier.h"
2525
#endif // NDEBUG
26-
#include "llvm/ADT/Triple.h"
2726
#include "llvm/Passes/PassBuilder.h"
27+
#include "llvm/TargetParser/Triple.h"
2828
#include "llvm/Transforms/InstCombine/InstCombine.h"
2929
#include "llvm/Transforms/Scalar/ADCE.h"
3030
#include "llvm/Transforms/Scalar/EarlyCSE.h"
@@ -103,7 +103,7 @@ FusionPipeline::runFusionPasses(Module &Mod, SYCLModuleInfo &InputInfo,
103103
// to/from generic address-space as possible, because these hinder
104104
// internalization.
105105
// Ideally, the static compiler should have performed that job.
106-
unsigned FlatAddressSpace = getFlatAddressSpace(Mod);
106+
const unsigned FlatAddressSpace = getFlatAddressSpace(Mod);
107107
FPM.addPass(InferAddressSpacesPass(FlatAddressSpace));
108108
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
109109
}

sycl-fusion/jit-compiler/lib/fusion/ModuleHelper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ helper::ModuleHelper::cloneAndPruneModule(Module *Mod,
2424
identifyUnusedFunctions(Mod, CGRoots, UnusedFunctions);
2525

2626
{
27-
auto TFI = llvm::TargetFusionInfo::getTargetFusionInfo(Mod);
27+
TargetFusionInfo TFI{Mod};
2828
SmallVector<Function *> Unused{UnusedFunctions.begin(),
2929
UnusedFunctions.end()};
3030
TFI.notifyFunctionsDelete(Unused);

sycl-fusion/jit-compiler/lib/translation/KernelTranslation.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "KernelTranslation.h"
10+
1011
#include "SPIRVLLVMTranslation.h"
1112
#include "llvm/Bitcode/BitcodeReader.h"
1213
#include "llvm/IR/Constants.h"
@@ -182,7 +183,8 @@ llvm::Error KernelTranslator::translateKernel(SYCLKernelInfo &Kernel,
182183
break;
183184
}
184185
case BinaryFormat::PTX: {
185-
llvm::Expected<KernelBinary *> BinaryOrError = translateToPTX(Mod, JITCtx);
186+
llvm::Expected<KernelBinary *> BinaryOrError =
187+
translateToPTX(Kernel, Mod, JITCtx);
186188
if (auto Error = BinaryOrError.takeError()) {
187189
return Error;
188190
}
@@ -215,12 +217,20 @@ KernelTranslator::translateToSPIRV(llvm::Module &Mod, JITContext &JITCtx) {
215217
}
216218

217219
llvm::Expected<KernelBinary *>
218-
KernelTranslator::translateToPTX(llvm::Module &Mod, JITContext &JITCtx) {
219-
// FIXME: Can we limit this to the NVPTX specific target?
220-
llvm::InitializeAllTargets();
221-
llvm::InitializeAllAsmParsers();
222-
llvm::InitializeAllAsmPrinters();
223-
llvm::InitializeAllTargetMCs();
220+
KernelTranslator::translateToPTX(SYCLKernelInfo &KernelInfo, llvm::Module &Mod,
221+
JITContext &JITCtx) {
222+
#ifndef FUSION_JIT_SUPPORT_PTX
223+
return createStringError(inconvertibleErrorCode(),
224+
"PTX translation not supported in this build");
225+
#else // FUSION_JIT_SUPPORT_PTX
226+
LLVMInitializeNVPTXTargetInfo();
227+
LLVMInitializeNVPTXTarget();
228+
LLVMInitializeNVPTXAsmPrinter();
229+
LLVMInitializeNVPTXTargetMC();
230+
#endif // FUSION_JIT_SUPPORT_PTX
231+
232+
static const char *TARGET_CPU_ATTRIBUTE = "target-cpu";
233+
static const char *TARGET_FEATURE_ATTRIBUTE = "target-features";
224234

225235
std::string TargetTriple{"nvptx64-nvidia-cuda"};
226236

@@ -231,13 +241,26 @@ KernelTranslator::translateToPTX(llvm::Module &Mod, JITContext &JITCtx) {
231241
if (!Target) {
232242
return createStringError(
233243
inconvertibleErrorCode(),
234-
"Failed to load and translate SPIR-V module with error %s",
244+
"Failed to load and translate PTX LLVM IR module with error %s",
235245
ErrorMessage.c_str());
236246
}
237247

248+
llvm::StringRef TargetCPU{"sm_50"};
249+
llvm::StringRef TargetFeatures{"+sm_50,+ptx76"};
250+
if (auto *KernelFunc = Mod.getFunction(KernelInfo.Name)) {
251+
if (KernelFunc->hasFnAttribute(TARGET_CPU_ATTRIBUTE)) {
252+
TargetCPU =
253+
KernelFunc->getFnAttribute(TARGET_CPU_ATTRIBUTE).getValueAsString();
254+
}
255+
if (KernelFunc->hasFnAttribute(TARGET_FEATURE_ATTRIBUTE)) {
256+
TargetFeatures = KernelFunc->getFnAttribute(TARGET_FEATURE_ATTRIBUTE)
257+
.getValueAsString();
258+
}
259+
}
260+
238261
// FIXME: Check whether we can provide more accurate target information here
239262
auto *TargetMachine = Target->createTargetMachine(
240-
TargetTriple, "sm_50", "+sm_50,+ptx76", {}, llvm::Reloc::PIC_,
263+
TargetTriple, TargetCPU, TargetFeatures, {}, llvm::Reloc::PIC_,
241264
std::nullopt, llvm::CodeGenOpt::Default);
242265

243266
llvm::legacy::PassManager PM;
@@ -259,5 +282,5 @@ KernelTranslator::translateToPTX(llvm::Module &Mod, JITContext &JITCtx) {
259282
ASMStream.flush();
260283
}
261284

262-
return &JITCtx.emplaceSPIRVBinary(PTXASM, BinaryFormat::PTX);
285+
return &JITCtx.emplaceKernelBinary(std::move(PTXASM), BinaryFormat::PTX);
263286
}

0 commit comments

Comments
 (0)