Skip to content

Commit 1313bbf

Browse files
committed
Basic 1D-only implementation of coalescing
1 parent bfaed2b commit 1313bbf

File tree

1 file changed

+267
-1
lines changed

1 file changed

+267
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXMemOpts.cpp

Lines changed: 267 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/Pass.h"
2222
#include "llvm/Support/Debug.h"
2323
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
24+
#include "llvm/IR/IntrinsicsNVPTX.h"
2425

2526
#define DEBUG_TYPE "nvptx-mem-opts"
2627

@@ -41,15 +42,280 @@ namespace {
4142
StringRef getPassName() const override {
4243
return "Memory coalescing and prefetching";
4344
}
45+
46+
static std::string SYNC_THREADS_INTRINSIC_NAME;
47+
static std::string NVVM_READ_SREG_INTRINSIC_NAME;
48+
49+
enum IndexType {
50+
CONSTANT,
51+
ABSOLUTE_THREAD_ID,
52+
LOOP_INDUCTION
53+
};
54+
private:
55+
56+
// Helper functions
57+
void CoalesceMemCalls(LoadInst *LI, std::vector<IndexType> &indexValues);
58+
bool isCallCoalescable(LoadInst *LI, std::vector<IndexType> &indexValues);
59+
60+
std::vector<IndexType> isLoadingFromArray(LoadInst *LI);
61+
62+
Module *M;
63+
};
4464
};
4565

4666
// Out-of-class definitions for the static members declared in NVPTXMemOpts.
char NVPTXMemOpts::ID = 0;

// Intrinsic name (prefix) strings. NOTE(review): neither string is referenced
// in the visible part of this file -- confirm they are still needed.
std::string NVPTXMemOpts::SYNC_THREADS_INTRINSIC_NAME("llvm.nvvm.barrier0");
std::string NVPTXMemOpts::NVVM_READ_SREG_INTRINSIC_NAME("llvm.nvvm.read.ptx.sreg");
69+
70+
// A common pattern to calculate the abosolute index of a thread is:
71+
// idx = tid + ctaid * ntid
72+
// This function will check if the index is calculated in this way
73+
bool isAbsoluteThreadIndex(Value *idx) {
74+
auto sext = dyn_cast<SExtInst>(idx);
75+
if (!sext) { return false; }
76+
77+
auto val = sext->getOperand(0);
78+
auto add = dyn_cast<BinaryOperator>(val);
79+
80+
if (!add || add->getOpcode() != Instruction::Add) { return false; }
81+
82+
auto mul = dyn_cast<BinaryOperator>(add->getOperand(0));
83+
if (!mul || mul->getOpcode() != Instruction::Mul) { return false; }
84+
85+
auto tid = dyn_cast<CallInst>(mul->getOperand(0));
86+
auto ntid = dyn_cast<CallInst>(mul->getOperand(1));
87+
auto ctaid = dyn_cast<CallInst>(add->getOperand(1));
88+
89+
if (!tid || !ntid || !ctaid) { return false; }
90+
91+
return true;
92+
}
93+
94+
/*
95+
This function is quite complicated because we are trying to convert
96+
a single GEP instruction into a vector representing the index element.
97+
This will require traversing backwards to find the initial values being used as indexes
98+
99+
TODO: some arrays are two dimensional but represented as a single index.
100+
We need to handle this case next.
101+
*/
102+
void getIndexValues(GetElementPtrInst *GEP, std::vector<NVPTXMemOpts::IndexType> &indexValues) {
103+
// get first index value. There should be exactly one
104+
auto index_value = GEP->idx_begin();
105+
if (isa<ConstantInt>(index_value)) {
106+
indexValues.push_back(NVPTXMemOpts::IndexType::CONSTANT);
107+
return;
108+
} else if (isAbsoluteThreadIndex(cast<Value>(index_value))) {
109+
indexValues.push_back(NVPTXMemOpts::IndexType::ABSOLUTE_THREAD_ID);
110+
return;
111+
}
112+
113+
114+
return;
115+
}
116+
117+
// Return dimension's indexes for an array load instruction
118+
// return 0 if the value is not an array
119+
std::vector<NVPTXMemOpts::IndexType> NVPTXMemOpts::isLoadingFromArray(LoadInst *LI) {
120+
121+
std::vector<NVPTXMemOpts::IndexType> indexValues;
122+
assert(LI && "LI is null");
123+
auto GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
124+
if (!GEP) { return indexValues; }
125+
126+
auto ptr = GEP->getPointerOperand();
127+
auto ptrGEP = dyn_cast<GetElementPtrInst>(ptr);
128+
assert(!ptrGEP && "Nested GEP not supported");
129+
130+
// get index value. There should be exactly one
131+
auto idx = GEP->idx_begin();
132+
assert(idx != GEP->idx_end() && "No index found");
133+
134+
// TODO:: if more than one index, it is probably coalesced already
135+
// if (++idx != GEP->idx_end()) {
136+
// return indexValues;
137+
// }
138+
139+
getIndexValues(GEP, indexValues);
140+
return indexValues;
141+
}
142+
143+
// Placeholder for the store-side analysis. It looks at whether the store's
// address is produced by a GEP, but currently reports 0 in every case.
int isStoringToArray(StoreInst *SI) {
  assert(SI && "SI is null");

  if (isa<GetElementPtrInst>(SI->getPointerOperand())) {
    // GEP-addressed stores are not analysed yet.
    return 0;
  }

  return 0;
}
150+
151+
// Check if the index is a constant
152+
bool isIndexConstant(Value *idx) {
153+
return isa<ConstantInt>(idx);
154+
}
155+
156+
// Check if the index is a thread constant.
157+
// ie. the thread id. this is not a constant for all threads in a warp
158+
bool isIndexThreadConstant(Value *idx) {
159+
return false;
160+
}
161+
162+
163+
// Decides whether the load `LI` should be staged through shared memory.
//
// `indexValues` is the classification produced by isLoadingFromArray for the
// load's GEP index.
//
// Returns false when:
//   - the address is not a plain (non-nested) GEP,
//   - the index is unknown or CONSTANT (an index that is constant for all
//     threads in a warp cannot be coalesced),
//   - the GEP base pointer is not in global memory (address space 1), or
//   - any user of the loaded value is a store into shared memory
//     (address space 3), which is taken to mean the access has already been
//     staged by hand.
// Otherwise the load is assumed to be coalescable.
bool NVPTXMemOpts::isCallCoalescable(LoadInst *LI, std::vector<IndexType> &indexValues) {
  auto *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
  // Unsupported address shapes are legal IR; reject them instead of asserting.
  if (!GEP || isa<GetElementPtrInst>(GEP->getPointerOperand()))
    return false;

  if (indexValues.empty() || indexValues.front() == IndexType::CONSTANT)
    return false;

  // check if the gep is loading from global memory
  if (GEP->getPointerOperand()->getType()->getPointerAddressSpace() != 1)
    return false;

  // Look at every user: the load may have none (user_back() would then be
  // invalid) or several, and any one of them may be the shared-memory store.
  for (User *U : LI->users()) {
    auto *storeInst = dyn_cast<StoreInst>(U);
    if (storeInst && storeInst->getPointerAddressSpace() == 3)
      return false;
  }

  // TODO:: otherwise, for now, we will assume that the call is coalescable
  return true;
}
189+
190+
/*
191+
This function will coalesce memory calls.
192+
Example:
193+
194+
%0 = call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
195+
%1 = call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
196+
%mul = mul i32 %0, %1
197+
%2 = call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x()
198+
%add = add i32 %mul, %2
199+
%3 = call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
200+
%4 = call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
201+
%mul5 = mul i32 %3, %4
202+
%5 = call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.y()
203+
%add7 = add i32 %mul5, %5
204+
%idxprom = sext i32 %add7 to i64
205+
%arrayidx = getelementptr inbounds ptr, ptr %A2, i64 %idxprom
206+
%6 = load ptr, ptr %arrayidx, align 8
207+
%idxprom8 = sext i32 %i.0 to i64
208+
%arrayidx9 = getelementptr inbounds float, ptr %6, i64 %idxprom8
209+
%7 = load float, ptr %arrayidx9, align 4
210+
211+
Will give the following parameters:
212+
LI = %6
213+
indexValues = { %add7, %add5, %add, %mul, %add7, %add, %mul }
214+
215+
Will be coalesced to:
216+
217+
218+
219+
220+
*/
221+
222+
/*
223+
Rules regarding coalescing:
224+
- if the index is a constant for all threads in a warp, it cannot be coalesced
225+
- if the index is a constant for one thread but contiguous across a warp, it can be coalesced
226+
- if the index is a loop induction variable, it can be coalesced
227+
228+
Other memory accesses will be ignored for now
229+
*/
230+
// Rewrites the global-memory load `LI` so that its value is staged through a
// freshly created shared-memory (address space 3) array:
//
//   v    = load <ty>, GEP                   ; the original global access
//   slot = &sharedArray[0][zext(tid.x)]
//   store v, slot
//   call @llvm.nvvm.barrier0()
//   v'   = load <ty>, slot
//
// All uses of `LI` are redirected to v'. `LI` itself is left in place (with
// no uses) and must be erased by the caller.
//
// Preconditions (asserted): `LI` is non-null, its pointer operand is a GEP,
// and `indexValues` is non-empty. The member `M` must point at the module
// that contains `LI`.
//
// NOTE(review): the barrier is emitted at the position of the load; if that
// position is not reached by every thread of the block the barrier is
// hazardous -- confirm with a divergence check before enabling by default.
void NVPTXMemOpts::CoalesceMemCalls(LoadInst *LI, std::vector<IndexType> &indexValues) {
  assert(LI && "LI is null");
  assert(indexValues.size() > 0 && "indexValues is empty");
  auto *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
  assert(GEP && "GEP is null");

  LLVMContext &Ctx = M->getContext();

  // The buffer holds values of the type actually loaded. The GEP's source
  // element type can differ from it (e.g. an array or i8 GEP feeding a float
  // load), which would make the replacement below type-incorrect.
  Type *elemTy = LI->getType();

  // One slot per possible tid.x value. 1024 is the CUDA upper bound for the
  // x dimension of a thread block; a smaller buffer would be overrun by any
  // block wider than the buffer.
  // TODO:: derive this from launch-bound metadata when available to reduce
  // shared memory usage.
  const uint64_t maxThreadsX = 1024;
  auto *sharedArrayType = ArrayType::get(elemTy, maxThreadsX);
  auto *sharedArray = new GlobalVariable(*M, sharedArrayType,
                                         false, GlobalValue::InternalLinkage,
                                         UndefValue::get(sharedArrayType),
                                         "sharedArray", nullptr,
                                         GlobalValue::NotThreadLocal, 3, false);
  // Align the buffer for its element type so the loads/stores created below
  // (which use the ABI alignment) are not under-aligned.
  sharedArray->setAlignment(MaybeAlign(M->getDataLayout().getABITypeAlign(elemTy)));

  // Emit the whole staging sequence at the original load so the global read
  // happens at the same program point as before.
  IRBuilder<> Builder(LI);

  // First, load the value from the original array, keeping the original
  // load's alignment.
  auto *globalValue = Builder.CreateAlignedLoad(elemTy, GEP, LI->getAlign());

  // The original index is the absolute thread id; the shared array is
  // indexed by the block-local thread id instead.
  Function *tidIntrinsic = Intrinsic::getDeclaration(M, Intrinsic::nvvm_read_ptx_sreg_tid_x);
  Value *tidVal = Builder.CreateCall(tidIntrinsic, {});
  Value *tidZeroExt = Builder.CreateZExt(tidVal, Type::getInt64Ty(Ctx));

  // &sharedArray[0][tid]
  Value *indices[] = {ConstantInt::get(Type::getInt64Ty(Ctx), 0), tidZeroExt};
  Value *sharedGEP = Builder.CreateGEP(sharedArrayType, sharedArray, indices);

  // store the value from the original array to the shared memory array
  Builder.CreateStore(globalValue, sharedGEP);

  // __syncthreads(): make sure all threads have written to shared memory
  // before any thread reads from it.
  Function *syncThreads = Intrinsic::getDeclaration(M, Intrinsic::nvvm_barrier0);
  Builder.CreateCall(syncThreads, {});

  // Finally, replace the load location with the shared memory location.
  auto *sharedLoad = Builder.CreateLoad(elemTy, sharedGEP);
  LI->replaceAllUsesWith(sharedLoad);
  // LI is erased by the caller once iteration over the block is finished.
}
47284

48285
// Pass entry point. Scans every load in `F`, and for each one that
// isLoadingFromArray classifies and isCallCoalescable accepts, rewrites it
// via CoalesceMemCalls and then erases the now-unused original load.
//
// Returns true if at least one load was rewritten (i.e. the IR changed),
// false otherwise.
bool NVPTXMemOpts::runOnFunction(Function &F) {
  M = F.getParent();
  bool changed = false;

  LLVM_DEBUG(dbgs() << "Hello from NVPTXMemOpts\n");

  // Rewritten loads are collected and erased after the block has been
  // walked, so the instruction being visited is never deleted mid-iteration.
  std::vector<LoadInst *> toDelete;
  for (auto &BB : F) {
    for (Instruction &I : BB) {
      if (auto *LI = dyn_cast<LoadInst>(&I)) {
        auto indexValues = isLoadingFromArray(LI);
        if (indexValues.empty())
          continue;
        if (isCallCoalescable(LI, indexValues)) {
          LLVM_DEBUG(dbgs() << "Found a candidate instruction: " << *LI << "\n");
          CoalesceMemCalls(LI, indexValues);
          toDelete.push_back(LI);
          changed = true;
        }
      } else if (auto *SI = dyn_cast<StoreInst>(&I)) {
        if (isStoringToArray(SI) > 0) {
          LLVM_DEBUG(dbgs() << "Found a store instruction: " << *SI << "\n");
        }
      }
    }

    for (LoadInst *LI : toDelete) {
      // CoalesceMemCalls redirected every use to the shared-memory load.
      assert(LI->use_empty() && "replaced load still has uses");
      LI->eraseFromParent();
    }
    toDelete.clear();
  }

  return changed;
}
51317

52-
} // end anonymous namespace
318+
// } // end anonymous namespace
53319

54320
namespace llvm {
55321
void initializeNVPTXMemOptsPass(PassRegistry &);

0 commit comments

Comments
 (0)