@@ -60,7 +60,7 @@ struct TransformParamsToConstRefContext
60
60
case kIROp_FieldExtract :
61
61
{
62
62
// Transform the IRFieldExtract into a IRFieldAddress
63
- auto fieldExtract = as<IRFieldExtract>(use-> getUser () );
63
+ auto fieldExtract = as<IRFieldExtract>(user );
64
64
builder.setInsertBefore (fieldExtract);
65
65
auto fieldAddr = builder.emitFieldAddress (
66
66
fieldExtract->getBase (),
@@ -73,8 +73,7 @@ struct TransformParamsToConstRefContext
73
73
case kIROp_GetElement :
74
74
{
75
75
// Transform the IRGetElement into a IRGetElementPtr
76
- auto getElement = as<IRGetElement>(use->getUser ());
77
-
76
+ auto getElement = as<IRGetElement>(user);
78
77
builder.setInsertBefore (getElement);
79
78
auto elemAddr = builder.emitElementAddress (
80
79
getElement->getBase (),
@@ -111,14 +110,8 @@ struct TransformParamsToConstRefContext
111
110
List<IRInst*> newArgs;
112
111
113
112
// Transform arguments to match the updated-parameter
114
- IRParam* param = func->getFirstParam ();
115
113
UInt i = 0 ;
116
- auto iterate = [&]()
117
- {
118
- param = param->getNextParam ();
119
- i++;
120
- };
121
- for (; param; iterate())
114
+ for (IRParam* param = func->getFirstParam (); param; param = param->getNextParam (), i++)
122
115
{
123
116
auto arg = call->getArg (i);
124
117
if (!updatedParams.contains (param))
@@ -179,11 +172,191 @@ struct TransformParamsToConstRefContext
179
172
return true ;
180
173
}
181
174
175
+ // Eliminate unnecessary load+store pairs that can be optimized away
176
+ // Only optimize patterns related to the updated parameters
177
+ void eliminateLoadStorePairs (IRFunc* func, HashSet<IRParam*>& updatedParams)
178
+ {
179
+ List<IRInst*> toRemove;
180
+
181
+ // Look for patterns: load(ptr) followed by store(var, load_result)
182
+ // This can be optimized to direct pointer usage
183
+ for (auto block : func->getBlocks ())
184
+ {
185
+ for (auto inst : block->getChildren ())
186
+ {
187
+ if (auto storeInst = as<IRStore>(inst))
188
+ {
189
+ auto storedValue = storeInst->getVal ();
190
+ auto destPtr = storeInst->getPtr ();
191
+
192
+ // Check if we're storing a load result
193
+ if (auto loadInst = as<IRLoad>(storedValue))
194
+ {
195
+ auto loadedPtr = loadInst->getPtr ();
196
+
197
+ // Only optimize if the loaded pointer is related to our updated parameters
198
+ bool isRelatedToUpdatedParam = false ;
199
+ for (auto param : updatedParams)
200
+ {
201
+ if (loadedPtr == param)
202
+ {
203
+ isRelatedToUpdatedParam = true ;
204
+ break ;
205
+ }
206
+ }
207
+
208
+ if (!isRelatedToUpdatedParam)
209
+ continue ;
210
+
211
+ // IMPORTANT: Only optimize if destPtr is a variable (kIROp_Var)
212
+ // Don't optimize stores to buffer elements or other meaningful destinations
213
+ if (!as<IRVar>(destPtr))
214
+ continue ;
215
+
216
+ // Check if this load has only one use (this store)
217
+ bool loadHasOnlyOneUse = true ;
218
+ UInt useCount = 0 ;
219
+ for (auto use = loadInst->firstUse ; use; use = use->nextUse )
220
+ {
221
+ useCount++;
222
+ if (useCount > 1 || use->getUser () != storeInst)
223
+ {
224
+ loadHasOnlyOneUse = false ;
225
+ break ;
226
+ }
227
+ }
228
+
229
+ if (loadHasOnlyOneUse)
230
+ {
231
+ // Replace all uses of destPtr with loadedPtr
232
+ destPtr->replaceUsesWith (loadedPtr);
233
+
234
+ // Mark both instructions for removal
235
+ toRemove.add (storeInst);
236
+ toRemove.add (loadInst);
237
+
238
+ // Also remove the variable if it's only used by this store
239
+ if (auto varInst = as<IRVar>(destPtr))
240
+ {
241
+ bool varHasOnlyStoreUse = true ;
242
+ for (auto use = varInst->firstUse ; use; use = use->nextUse )
243
+ {
244
+ if (use->getUser () != storeInst)
245
+ {
246
+ varHasOnlyStoreUse = false ;
247
+ break ;
248
+ }
249
+ }
250
+ if (varHasOnlyStoreUse)
251
+ {
252
+ toRemove.add (varInst);
253
+ }
254
+ }
255
+
256
+ changed = true ;
257
+ }
258
+ }
259
+ }
260
+ }
261
+ }
262
+
263
+ // Remove marked instructions
264
+ for (auto inst : toRemove)
265
+ {
266
+ if (inst->getParent ())
267
+ {
268
+ inst->removeAndDeallocate ();
269
+ }
270
+ }
271
+ }
272
+
273
+ // Eliminate load+store pairs for entry point functions (without specific updatedParams)
274
+ void eliminateLoadStorePairsForEntryPoint (IRFunc* func)
275
+ {
276
+ List<IRInst*> toRemove;
277
+
278
+ // Look for patterns: load(ptr) followed by store(var, load_result)
279
+ // This can be optimized to direct pointer usage
280
+ for (auto block : func->getBlocks ())
281
+ {
282
+ for (auto inst : block->getChildren ())
283
+ {
284
+ if (auto storeInst = as<IRStore>(inst))
285
+ {
286
+ auto storedValue = storeInst->getVal ();
287
+ auto destPtr = storeInst->getPtr ();
288
+
289
+ // Check if we're storing a load result
290
+ if (auto loadInst = as<IRLoad>(storedValue))
291
+ {
292
+ auto loadedPtr = loadInst->getPtr ();
293
+
294
+ // IMPORTANT: Only optimize if destPtr is a variable (kIROp_Var)
295
+ // Don't optimize stores to buffer elements or other meaningful destinations
296
+ if (!as<IRVar>(destPtr))
297
+ continue ;
298
+
299
+ // Check if this load has only one use (this store)
300
+ bool loadHasOnlyOneUse = true ;
301
+ UInt useCount = 0 ;
302
+ for (auto use = loadInst->firstUse ; use; use = use->nextUse )
303
+ {
304
+ useCount++;
305
+ if (useCount > 1 || use->getUser () != storeInst)
306
+ {
307
+ loadHasOnlyOneUse = false ;
308
+ break ;
309
+ }
310
+ }
311
+
312
+ if (loadHasOnlyOneUse)
313
+ {
314
+ // Replace all uses of destPtr with loadedPtr
315
+ destPtr->replaceUsesWith (loadedPtr);
316
+
317
+ // Mark both instructions for removal
318
+ toRemove.add (storeInst);
319
+ toRemove.add (loadInst);
320
+
321
+ // Also remove the variable if it's only used by this store
322
+ if (auto varInst = as<IRVar>(destPtr))
323
+ {
324
+ bool varHasOnlyStoreUse = true ;
325
+ for (auto use = varInst->firstUse ; use; use = use->nextUse )
326
+ {
327
+ if (use->getUser () != storeInst)
328
+ {
329
+ varHasOnlyStoreUse = false ;
330
+ break ;
331
+ }
332
+ }
333
+ if (varHasOnlyStoreUse)
334
+ {
335
+ toRemove.add (varInst);
336
+ }
337
+ }
338
+
339
+ changed = true ;
340
+ }
341
+ }
342
+ }
343
+ }
344
+ }
345
+
346
+ // Remove marked instructions
347
+ for (auto inst : toRemove)
348
+ {
349
+ if (inst->getParent ())
350
+ {
351
+ inst->removeAndDeallocate ();
352
+ }
353
+ }
354
+ }
355
+
182
356
// Process a single function
183
357
void processFunc (IRFunc* func)
184
358
{
185
359
HashSet<IRParam*> updatedParams;
186
- bool hasTransformedParams = false ;
187
360
188
361
// First pass: Transform parameter types
189
362
for (auto param = func->getFirstParam (); param; param = param->getNextParam ())
@@ -203,13 +376,12 @@ struct TransformParamsToConstRefContext
203
376
auto constRefType = builder.getConstRefType (paramType, AddressSpace::ThreadLocal);
204
377
param->setFullType (constRefType);
205
378
206
- hasTransformedParams = true ;
207
379
changed = true ;
208
380
updatedParams.add (param);
209
381
}
210
382
}
211
383
212
- if (!hasTransformedParams )
384
+ if (updatedParams. getCount () == 0 )
213
385
{
214
386
return ;
215
387
}
@@ -219,6 +391,9 @@ struct TransformParamsToConstRefContext
219
391
220
392
// Third pass: Update call sites
221
393
updateCallSites (func, updatedParams);
394
+
395
+ // Optimization pass: Remove unnecessary load+store pairs created by updateCallSites
396
+ eliminateLoadStorePairs (func, updatedParams);
222
397
}
223
398
224
399
void addFuncsToCallListInTopologicalOrder (
@@ -274,6 +449,22 @@ struct TransformParamsToConstRefContext
274
449
processFunc (func);
275
450
}
276
451
452
+ // Handle entry point functions separately - they don't get processed by processFunc
453
+ // but they still need load/store optimization for parameters that come from global parameters
454
+ for (auto inst = module ->getModuleInst ()->getFirstChild (); inst; inst = inst->getNextInst ())
455
+ {
456
+ auto func = as<IRFunc>(inst);
457
+ if (!func || !func->isDefinition ())
458
+ continue ;
459
+
460
+ // Only process entry point functions that weren't already processed
461
+ if (!shouldProcessFunction (func))
462
+ {
463
+ // For entry point functions, use a broader optimization since we don't have specific updatedParams
464
+ eliminateLoadStorePairsForEntryPoint (func);
465
+ }
466
+ }
467
+
277
468
return SLANG_OK;
278
469
}
279
470
};
0 commit comments