2121#include " llvm/Pass.h"
2222#include " llvm/Support/Debug.h"
2323#include " llvm/Transforms/Utils/BasicBlockUtils.h"
24+ #include " llvm/IR/IntrinsicsNVPTX.h"
2425
2526#define DEBUG_TYPE " nvptx-mem-opts"
2627
@@ -41,15 +42,280 @@ namespace {
4142 StringRef getPassName () const override {
4243 return " Memory coalescing and prefetching" ;
4344 }
45+
46+ static std::string SYNC_THREADS_INTRINSIC_NAME;
47+ static std::string NVVM_READ_SREG_INTRINSIC_NAME;
48+
49+ enum IndexType {
50+ CONSTANT,
51+ ABSOLUTE_THREAD_ID,
52+ LOOP_INDUCTION
53+ };
54+ private:
55+
56+ // Helper functions
57+ void CoalesceMemCalls (LoadInst *LI, std::vector<IndexType> &indexValues);
58+ bool isCallCoalescable (LoadInst *LI, std::vector<IndexType> &indexValues);
59+
60+ std::vector<IndexType> isLoadingFromArray (LoadInst *LI);
61+
62+ Module *M;
63+ };
4464 };
4565
4666char NVPTXMemOpts::ID = 0 ;
67+ std::string NVPTXMemOpts::SYNC_THREADS_INTRINSIC_NAME = " llvm.nvvm.barrier0" ;
68+ std::string NVPTXMemOpts::NVVM_READ_SREG_INTRINSIC_NAME = " llvm.nvvm.read.ptx.sreg" ;
69+
70+ // A common pattern to calculate the abosolute index of a thread is:
71+ // idx = tid + ctaid * ntid
72+ // This function will check if the index is calculated in this way
73+ bool isAbsoluteThreadIndex (Value *idx) {
74+ auto sext = dyn_cast<SExtInst>(idx);
75+ if (!sext) { return false ; }
76+
77+ auto val = sext->getOperand (0 );
78+ auto add = dyn_cast<BinaryOperator>(val);
79+
80+ if (!add || add->getOpcode () != Instruction::Add) { return false ; }
81+
82+ auto mul = dyn_cast<BinaryOperator>(add->getOperand (0 ));
83+ if (!mul || mul->getOpcode () != Instruction::Mul) { return false ; }
84+
85+ auto tid = dyn_cast<CallInst>(mul->getOperand (0 ));
86+ auto ntid = dyn_cast<CallInst>(mul->getOperand (1 ));
87+ auto ctaid = dyn_cast<CallInst>(add->getOperand (1 ));
88+
89+ if (!tid || !ntid || !ctaid) { return false ; }
90+
91+ return true ;
92+ }
93+
94+ /*
95+ This function is quite complicated because we are trying to convert
96+ a single GEP instruction into a vector representing the index element.
97+ This will require traversing backwards to find the initial values being used as indexes
98+
99+ TODO: some arrays are two dimensional but represented as a single index.
100+ We need to handle this case next.
101+ */
102+ void getIndexValues (GetElementPtrInst *GEP, std::vector<NVPTXMemOpts::IndexType> &indexValues) {
103+ // get first index value. There should be exactly one
104+ auto index_value = GEP->idx_begin ();
105+ if (isa<ConstantInt>(index_value)) {
106+ indexValues.push_back (NVPTXMemOpts::IndexType::CONSTANT);
107+ return ;
108+ } else if (isAbsoluteThreadIndex (cast<Value>(index_value))) {
109+ indexValues.push_back (NVPTXMemOpts::IndexType::ABSOLUTE_THREAD_ID);
110+ return ;
111+ }
112+
113+
114+ return ;
115+ }
116+
117+ // Return dimension's indexes for an array load instruction
118+ // return 0 if the value is not an array
119+ std::vector<NVPTXMemOpts::IndexType> NVPTXMemOpts::isLoadingFromArray (LoadInst *LI) {
120+
121+ std::vector<NVPTXMemOpts::IndexType> indexValues;
122+ assert (LI && " LI is null" );
123+ auto GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand ());
124+ if (!GEP) { return indexValues; }
125+
126+ auto ptr = GEP->getPointerOperand ();
127+ auto ptrGEP = dyn_cast<GetElementPtrInst>(ptr);
128+ assert (!ptrGEP && " Nested GEP not supported" );
129+
130+ // get index value. There should be exactly one
131+ auto idx = GEP->idx_begin ();
132+ assert (idx != GEP->idx_end () && " No index found" );
133+
134+ // TODO:: if more than one index, it is probably coalesced already
135+ // if (++idx != GEP->idx_end()) {
136+ // return indexValues;
137+ // }
138+
139+ getIndexValues (GEP, indexValues);
140+ return indexValues;
141+ }
142+
143+ int isStoringToArray (StoreInst *SI) {
144+ assert (SI && " SI is null" );
145+ auto GEP = dyn_cast<GetElementPtrInst>(SI->getPointerOperand ());
146+ if (!GEP) { return 0 ; }
147+
148+ return 0 ;
149+ }
150+
151+ // Check if the index is a constant
152+ bool isIndexConstant (Value *idx) {
153+ return isa<ConstantInt>(idx);
154+ }
155+
156+ // Check if the index is a thread constant.
157+ // ie. the thread id. this is not a constant for all threads in a warp
158+ bool isIndexThreadConstant (Value *idx) {
159+ return false ;
160+ }
161+
162+
163+ bool NVPTXMemOpts::isCallCoalescable (LoadInst *LI, std::vector<IndexType> &indexValues) {
164+ // Check if the call is already coalesced
165+ // We can do this by seeing if the call is already a load from shared memory
166+ // If it is, we can skip this call
167+ auto GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand ());
168+ assert (GEP && " GEP is null" );
169+ auto ptr = GEP->getPointerOperand ();
170+ auto ptrGEP = dyn_cast<GetElementPtrInst>(ptr);
171+ assert (!ptrGEP && " Nested GEP not supported" );
172+
173+ // check if the loaded float is being used by a store into shared memory (addressspace 3)
174+ // if it is, we can skip this call
175+
176+ // check if the gep is loading from global memory
177+ if (GEP->getPointerOperand ()->getType ()->getPointerAddressSpace () != 1 ) {
178+ return false ;
179+ }
180+
181+ auto storeInst = dyn_cast<StoreInst>(LI->user_back ());
182+ if (storeInst && storeInst->getPointerAddressSpace () == 3 ) {
183+ return false ;
184+ }
185+
186+ // TODO:: otherwise, for now, we will assume that the call is coalescable
187+ return true ;
188+ }
189+
190+ /*
191+ This function will coalesce memory calls.
192+ Example:
193+
194+ %0 = call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
195+ %1 = call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
196+ %mul = mul i32 %0, %1
197+ %2 = call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x()
198+ %add = add i32 %mul, %2
199+ %3 = call noundef i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
200+ %4 = call noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
201+ %mul5 = mul i32 %3, %4
202+ %5 = call noundef i32 @llvm.nvvm.read.ptx.sreg.tid.y()
203+ %add7 = add i32 %mul5, %5
204+ %idxprom = sext i32 %add7 to i64
205+ %arrayidx = getelementptr inbounds ptr, ptr %A2, i64 %idxprom
206+ %6 = load ptr, ptr %arrayidx, align 8
207+ %idxprom8 = sext i32 %i.0 to i64
208+ %arrayidx9 = getelementptr inbounds float, ptr %6, i64 %idxprom8
209+ %7 = load float, ptr %arrayidx9, align 4
210+
211+ Will give the following parameters:
212+ LI = %6
213+ indexValues = { %add7, %add5, %add, %mul, %add7, %add, %mul }
214+
215+ Will be coalesced to:
216+
217+
218+
219+
220+ */
221+
222+ /*
223+ Rules regarding coalescing:
224+ - if the index is a constant for all threads in a warp, it cannot be coalesced
225+ - if the index is a constant for one thread but contiguous across a warp, it can be coalesced
226+ - if the index is a loop induction variable, it can be coalesced
227+
228+ Other memory accesses will be ignored for now
229+ */
230+ void NVPTXMemOpts::CoalesceMemCalls (LoadInst *LI, std::vector<IndexType> &indexValues) {
231+ assert (LI && " LI is null" );
232+ assert (indexValues.size () > 0 && " indexValues is empty" );
233+ auto GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand ());
234+ assert (GEP && " GEP is null" );
235+
236+ // First, we need to create a shared memory buffer to store the data
237+ // We will use the same type as the original array
238+ auto arrayType = GEP->getSourceElementType ();
239+ // TODO:: for now, we are assuming type is int64 or float64;
240+ // we need 8 bytes per element, 64 bytes per warp
241+ int arraySize = 16 ;
242+ // Create the array type
243+ auto sharedArrayType = ArrayType::get (arrayType, arraySize);
244+ auto arrayInitVal = UndefValue::get (sharedArrayType); // TODO:: see why this is not working as the array initializer below
245+ auto sharedArray = new GlobalVariable (*M, sharedArrayType,
246+ false , GlobalValue::InternalLinkage,
247+ arrayInitVal, " sharedArray" , nullptr ,
248+ GlobalValue::NotThreadLocal, 3 , false );
249+ sharedArray->setAlignment (MaybeAlign (4 ));
250+
251+ IRBuilder<> Builder (GEP->getNextNode ());
252+ // First, we need to load the value from the original array.
253+ // This will be loaded into shared memory.
254+ auto LoadInst = Builder.CreateLoad (GEP->getSourceElementType (), GEP);
255+ // Next, we need to calculate the index for the shared memory array
256+ // The original index is the absolute thread id. we need to convert this to tid
257+ // first, get the thread id. Find the instrinsic call that is already in the function
258+ auto TidInstrinsic = Intrinsic::getDeclaration (M, Intrinsic::nvvm_read_ptx_sreg_tid_x);
259+ // get the register that reads the thread id
260+ auto TidVal = Builder.CreateCall (TidInstrinsic, {});
261+
262+ // Next, we need to calculate the index for the shared memory array
263+ // first, zero extend the tid value
264+ auto TidZeroExt = Builder.CreateZExt (TidVal, Type::getInt64Ty (M->getContext ()));
265+ // Next, create a GEP to calculate the index of shared memory
266+ auto ZeroVal = ConstantInt::get (Type::getInt64Ty (M->getContext ()), 0 );
267+ auto SharedGEP = Builder.CreateGEP (sharedArrayType, sharedArray, std::vector<Value*>{ZeroVal, TidZeroExt});
268+
269+ // store the value from the original array to the shared memory array
270+ Builder.CreateStore (LoadInst, SharedGEP);
271+
272+ // We need to insert __syncthreads() before the load instruction
273+ // This is to ensure that all threads have written to shared memory before we read from it
274+ auto syncThreads = Intrinsic::getDeclaration (M, Intrinsic::nvvm_barrier0);
275+ Builder.CreateCall (syncThreads, {});
276+
277+ // Finally, replace the load location with the shared memory location
278+ Builder.SetInsertPoint (LI);
279+ auto SharedLoad = Builder.CreateLoad (GEP->getSourceElementType (), SharedGEP);
280+ LI->replaceAllUsesWith (SharedLoad);
281+ // LI->eraseFromParent();
282+
283+ }
47284
48285bool NVPTXMemOpts::runOnFunction (Function &F) {
286+ M = F.getParent ();
287+
288+ errs () << " Hello from NVPTXMemOpts\n " ;
289+ std::vector<LoadInst*> toDelete;
290+ for (auto &BB : F) {
291+ for (auto I = BB.begin (); I != BB.end (); ++I){
292+ if (auto *LI = dyn_cast<LoadInst>(&*I)) {
293+ auto indexValues = isLoadingFromArray (LI);
294+ if (indexValues.empty ())
295+ continue ;
296+ if (isCallCoalescable (LI, indexValues)) {
297+ errs () << " Found a candidate instruction: " << *LI << " \n " ;
298+ CoalesceMemCalls (LI, indexValues);
299+ toDelete.push_back (LI);
300+ }
301+ }
302+ if (auto *SI = dyn_cast<StoreInst>(&*I)) {
303+ if (isStoringToArray (SI) > 0 ) {
304+ errs () << " Found a store instruction: " << *SI << " \n " ;
305+ }
306+ }
307+ }
308+ for (auto LI : toDelete) {
309+ // asser that the LI has no uses
310+ assert (LI->use_empty ());
311+ LI->eraseFromParent ();
312+ }
313+ toDelete.clear ();
314+ }
49315 return false ;
50316}
51317
52- } // end anonymous namespace
318+ // } // end anonymous namespace
53319
54320namespace llvm {
55321void initializeNVPTXMemOptsPass (PassRegistry &);
0 commit comments