2121#include " llvm/Pass.h"
2222#include " llvm/Support/Debug.h"
2323#include " llvm/Transforms/Utils/BasicBlockUtils.h"
24+ #include " llvm/IR/IntrinsicsNVPTX.h"
2425
2526#define DEBUG_TYPE " nvptx-mem-opts"
2627
@@ -41,15 +42,215 @@ namespace {
4142 StringRef getPassName () const override {
4243 return " Memory coalescing and prefetching" ;
4344 }
45+
46+ static std::string SYNC_THREADS_INTRINSIC_NAME;
47+ static std::string NVVM_READ_SREG_INTRINSIC_NAME;
48+
49+ enum IndexType {
50+ CONSTANT,
51+ ABSOLUTE_THREAD_ID,
52+ LOOP_INDUCTION
53+ };
54+ private:
55+
56+ // Helper functions
57+ void CoalesceMemCalls (LoadInst *LI, std::vector<IndexType> &indexValues);
58+ bool isCallCoalescable (LoadInst *LI, std::vector<IndexType> &indexValues);
59+
60+ std::vector<IndexType> isLoadingFromArray (LoadInst *LI);
61+
62+ Module *M;
63+ };
4464 };
4565
4666char NVPTXMemOpts::ID = 0 ;
67+ std::string NVPTXMemOpts::SYNC_THREADS_INTRINSIC_NAME = " llvm.nvvm.barrier0" ;
68+ std::string NVPTXMemOpts::NVVM_READ_SREG_INTRINSIC_NAME = " llvm.nvvm.read.ptx.sreg" ;
69+
70+ // A common pattern to calculate the abosolute index of a thread is:
71+ // idx = tid + ctaid * ntid
72+ // This function will check if an index is calculated in this way
73+ bool isAbsoluteThreadIndex (Value *idx) {
74+ auto sext = dyn_cast<SExtInst>(idx);
75+ if (!sext) { return false ; }
76+
77+ auto val = sext->getOperand (0 );
78+ auto add = dyn_cast<BinaryOperator>(val);
79+
80+ if (!add || add->getOpcode () != Instruction::Add) { return false ; }
81+
82+ auto mul = dyn_cast<BinaryOperator>(add->getOperand (0 ));
83+ if (!mul || mul->getOpcode () != Instruction::Mul) { return false ; }
84+
85+ auto tid = dyn_cast<CallInst>(mul->getOperand (0 ));
86+ auto ntid = dyn_cast<CallInst>(mul->getOperand (1 ));
87+ auto ctaid = dyn_cast<CallInst>(add->getOperand (1 ));
88+
89+ if (!tid || !ntid || !ctaid) { return false ; }
90+
91+ return true ;
92+ }
93+
94+ /*
95+ This function is quite complicated because we are trying to convert
96+ a single GEP instruction into a vector representing the index element.
97+ This will require traversing backwards to find the initial values being used as indexes
98+
99+ TODO: some arrays are two dimensional but represented as a single index.
100+ We need to handle this case next.
101+ */
102+ void getIndexValues (GetElementPtrInst *GEP, std::vector<NVPTXMemOpts::IndexType> &indexValues) {
103+ // get first index value. There should be exactly one
104+ auto index_value = GEP->idx_begin ();
105+ if (isa<ConstantInt>(index_value)) {
106+ indexValues.push_back (NVPTXMemOpts::IndexType::CONSTANT);
107+ return ;
108+ } else if (isAbsoluteThreadIndex (cast<Value>(index_value))) {
109+ indexValues.push_back (NVPTXMemOpts::IndexType::ABSOLUTE_THREAD_ID);
110+ return ;
111+ }
112+
113+
114+ return ;
115+ }
116+
117+ // This function will check if the load instruction is loading from an array
118+ // If it is, it will return the index value types used to access the array
119+ // If not, it will return an empty vector
120+ std::vector<NVPTXMemOpts::IndexType> NVPTXMemOpts::isLoadingFromArray (LoadInst *LI) {
121+
122+ std::vector<NVPTXMemOpts::IndexType> indexValues;
123+ assert (LI && " LI is null" );
124+ auto GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand ());
125+ if (!GEP) { return indexValues; }
126+
127+ auto ptr = GEP->getPointerOperand ();
128+ auto ptrGEP = dyn_cast<GetElementPtrInst>(ptr);
129+ assert (!ptrGEP && " Nested GEP not supported" );
130+
131+ // get index value. There should be exactly one
132+ auto idx = GEP->idx_begin ();
133+ assert (idx != GEP->idx_end () && " No index found" );
134+
135+ getIndexValues (GEP, indexValues);
136+ return indexValues;
137+ }
138+
139+ /*
140+ Rules regarding coalescing:
141+ - if the index is a constant for all threads in a warp, it cannot be coalesced
142+ - if the index is a constant for one thread but contiguous across a warp, it can be coalesced
143+ - if the index is a loop induction variable, it can be coalesced
144+
145+ Other memory accesses will be ignored for now
146+ */
147+ bool NVPTXMemOpts::isCallCoalescable (LoadInst *LI, std::vector<IndexType> &indexValues) {
148+ auto GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand ());
149+ assert (GEP && " GEP is null" );
150+ auto ptr = GEP->getPointerOperand ();
151+ auto ptrGEP = dyn_cast<GetElementPtrInst>(ptr);
152+ assert (!ptrGEP && " Nested GEP not supported" );
153+
154+ // We only consider loads from global memory. Filters out already coalesced loads
155+ if (GEP->getPointerOperand ()->getType ()->getPointerAddressSpace () != 1 ) {
156+ return false ;
157+ }
158+
159+ // If the load is being stored to shared memory, it cannot be coalesced
160+ // It is probably already coalesced
161+ auto storeInst = dyn_cast<StoreInst>(LI->user_back ());
162+ if (storeInst && storeInst->getPointerAddressSpace () == 3 ) {
163+ return false ;
164+ }
165+
166+ // TODO:: there will be other considerations
167+ // otherwise, we assume the call is coalescable
168+ return true ;
169+ }
170+
171+ void NVPTXMemOpts::CoalesceMemCalls (LoadInst *LI, std::vector<IndexType> &indexValues) {
172+ assert (LI && " LI is null" );
173+ assert (indexValues.size () > 0 && " indexValues is empty" );
174+ auto GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand ());
175+ assert (GEP && " GEP is null" );
176+
177+ // First, we need to create a shared memory buffer to store the data
178+ // We will use the same type as the original array
179+ auto arrayType = GEP->getSourceElementType ();
180+ // TODO:: for now, we are assuming type is int64 or float64;
181+ // we need 8 bytes per element, 64 bytes per warp
182+ int arraySize = 16 ;
183+ // Create the array type
184+ auto sharedArrayType = ArrayType::get (arrayType, arraySize);
185+ auto arrayInitVal = UndefValue::get (sharedArrayType); // TODO:: see why this is not working as the array initializer below
186+ auto sharedArray = new GlobalVariable (*M, sharedArrayType,
187+ false , GlobalValue::InternalLinkage,
188+ arrayInitVal, " sharedArray" , nullptr ,
189+ GlobalValue::NotThreadLocal, 3 , false );
190+ sharedArray->setAlignment (MaybeAlign (4 ));
191+
192+ IRBuilder<> Builder (GEP->getNextNode ());
193+ // First, we need to load the value from the original array.
194+ // This will be loaded into shared memory.
195+ auto LoadInst = Builder.CreateLoad (GEP->getSourceElementType (), GEP);
196+ // Next, we need to calculate the index for the shared memory array
197+ // The original index is the absolute thread id. we need to convert this to tid
198+ // first, get the thread id. Find the instrinsic call that is already in the function
199+ auto TidInstrinsic = Intrinsic::getDeclaration (M, Intrinsic::nvvm_read_ptx_sreg_tid_x);
200+ // get the register that reads the thread id
201+ auto TidVal = Builder.CreateCall (TidInstrinsic, {});
202+
203+ // Next, we need to calculate the index for the shared memory array
204+ // first, zero extend the tid value
205+ auto TidZeroExt = Builder.CreateZExt (TidVal, Type::getInt64Ty (M->getContext ()));
206+ // Next, create a GEP to calculate the index of shared memory
207+ auto ZeroVal = ConstantInt::get (Type::getInt64Ty (M->getContext ()), 0 );
208+ auto SharedGEP = Builder.CreateGEP (sharedArrayType, sharedArray, std::vector<Value*>{ZeroVal, TidZeroExt});
209+
210+ // store the value from the original array to the shared memory array
211+ Builder.CreateStore (LoadInst, SharedGEP);
212+
213+ // We need to insert __syncthreads() before the load instruction
214+ // This is to ensure that all threads have written to shared memory before we read from it
215+ auto syncThreads = Intrinsic::getDeclaration (M, Intrinsic::nvvm_barrier0);
216+ Builder.CreateCall (syncThreads, {});
217+
218+ // Finally, replace the load location with the shared memory location
219+ Builder.SetInsertPoint (LI);
220+ auto SharedLoad = Builder.CreateLoad (GEP->getSourceElementType (), SharedGEP);
221+ LI->replaceAllUsesWith (SharedLoad);
222+
223+ }
47224
48225bool NVPTXMemOpts::runOnFunction (Function &F) {
226+ M = F.getParent ();
227+
228+ errs () << " Hello from NVPTXMemOpts\n " ;
229+ std::vector<LoadInst*> toDelete;
230+ for (auto &BB : F) {
231+ for (auto I = BB.begin (); I != BB.end (); ++I){
232+ if (auto *LI = dyn_cast<LoadInst>(&*I)) {
233+ auto indexValues = isLoadingFromArray (LI);
234+ if (indexValues.empty ())
235+ continue ;
236+ if (isCallCoalescable (LI, indexValues)) {
237+ errs () << " Found a candidate instruction: " << *LI << " \n " ;
238+ CoalesceMemCalls (LI, indexValues);
239+ toDelete.push_back (LI);
240+ }
241+ }
242+ }
243+ for (auto LI : toDelete) {
244+ // assert that the LI has no uses
245+ assert (LI->use_empty ());
246+ LI->eraseFromParent ();
247+ }
248+ toDelete.clear ();
249+ }
49250 return false ;
50251}
51252
52- } // end anonymous namespace
253+ // } // end anonymous namespace
53254
54255namespace llvm {
55256void initializeNVPTXMemOptsPass (PassRegistry &);
0 commit comments