Skip to content

Commit bbb42ed

Browse files
committed
Remove unnecessary Load and Store pair
1 parent 3745d75 commit bbb42ed

File tree

4 files changed

+262
-21
lines changed

4 files changed

+262
-21
lines changed

source/slang/slang-ir-transform-params-to-constref.cpp

Lines changed: 204 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct TransformParamsToConstRefContext
6060
case kIROp_FieldExtract:
6161
{
6262
// Transform the IRFieldExtract into an IRFieldAddress
63-
auto fieldExtract = as<IRFieldExtract>(use->getUser());
63+
auto fieldExtract = as<IRFieldExtract>(user);
6464
builder.setInsertBefore(fieldExtract);
6565
auto fieldAddr = builder.emitFieldAddress(
6666
fieldExtract->getBase(),
@@ -73,8 +73,7 @@ struct TransformParamsToConstRefContext
7373
case kIROp_GetElement:
7474
{
7575
// Transform the IRGetElement into an IRGetElementPtr
76-
auto getElement = as<IRGetElement>(use->getUser());
77-
76+
auto getElement = as<IRGetElement>(user);
7877
builder.setInsertBefore(getElement);
7978
auto elemAddr = builder.emitElementAddress(
8079
getElement->getBase(),
@@ -111,14 +110,8 @@ struct TransformParamsToConstRefContext
111110
List<IRInst*> newArgs;
112111

113112
// Transform arguments to match the updated-parameter
114-
IRParam* param = func->getFirstParam();
115113
UInt i = 0;
116-
auto iterate = [&]()
117-
{
118-
param = param->getNextParam();
119-
i++;
120-
};
121-
for (; param; iterate())
114+
for (IRParam* param = func->getFirstParam(); param; param = param->getNextParam(), i++)
122115
{
123116
auto arg = call->getArg(i);
124117
if (!updatedParams.contains(param))
@@ -179,11 +172,191 @@ struct TransformParamsToConstRefContext
179172
return true;
180173
}
181174

175+
// Eliminate unnecessary load+store pairs that can be optimized away
176+
// Only optimize patterns related to the updated parameters
177+
void eliminateLoadStorePairs(IRFunc* func, HashSet<IRParam*>& updatedParams)
178+
{
179+
List<IRInst*> toRemove;
180+
181+
// Look for patterns: load(ptr) followed by store(var, load_result)
182+
// This can be optimized to direct pointer usage
183+
for (auto block : func->getBlocks())
184+
{
185+
for (auto inst : block->getChildren())
186+
{
187+
if (auto storeInst = as<IRStore>(inst))
188+
{
189+
auto storedValue = storeInst->getVal();
190+
auto destPtr = storeInst->getPtr();
191+
192+
// Check if we're storing a load result
193+
if (auto loadInst = as<IRLoad>(storedValue))
194+
{
195+
auto loadedPtr = loadInst->getPtr();
196+
197+
// Only optimize if the loaded pointer is related to our updated parameters
198+
bool isRelatedToUpdatedParam = false;
199+
for (auto param : updatedParams)
200+
{
201+
if (loadedPtr == param)
202+
{
203+
isRelatedToUpdatedParam = true;
204+
break;
205+
}
206+
}
207+
208+
if (!isRelatedToUpdatedParam)
209+
continue;
210+
211+
// IMPORTANT: Only optimize if destPtr is a variable (kIROp_Var)
212+
// Don't optimize stores to buffer elements or other meaningful destinations
213+
if (!as<IRVar>(destPtr))
214+
continue;
215+
216+
// Check if this load has only one use (this store)
217+
bool loadHasOnlyOneUse = true;
218+
UInt useCount = 0;
219+
for (auto use = loadInst->firstUse; use; use = use->nextUse)
220+
{
221+
useCount++;
222+
if (useCount > 1 || use->getUser() != storeInst)
223+
{
224+
loadHasOnlyOneUse = false;
225+
break;
226+
}
227+
}
228+
229+
if (loadHasOnlyOneUse)
230+
{
231+
// Replace all uses of destPtr with loadedPtr
232+
destPtr->replaceUsesWith(loadedPtr);
233+
234+
// Mark both instructions for removal
235+
toRemove.add(storeInst);
236+
toRemove.add(loadInst);
237+
238+
// Also remove the variable if it's only used by this store
239+
if (auto varInst = as<IRVar>(destPtr))
240+
{
241+
bool varHasOnlyStoreUse = true;
242+
for (auto use = varInst->firstUse; use; use = use->nextUse)
243+
{
244+
if (use->getUser() != storeInst)
245+
{
246+
varHasOnlyStoreUse = false;
247+
break;
248+
}
249+
}
250+
if (varHasOnlyStoreUse)
251+
{
252+
toRemove.add(varInst);
253+
}
254+
}
255+
256+
changed = true;
257+
}
258+
}
259+
}
260+
}
261+
}
262+
263+
// Remove marked instructions
264+
for (auto inst : toRemove)
265+
{
266+
if (inst->getParent())
267+
{
268+
inst->removeAndDeallocate();
269+
}
270+
}
271+
}
272+
273+
// Eliminate load+store pairs for entry point functions (without specific updatedParams)
274+
void eliminateLoadStorePairsForEntryPoint(IRFunc* func)
275+
{
276+
List<IRInst*> toRemove;
277+
278+
// Look for patterns: load(ptr) followed by store(var, load_result)
279+
// This can be optimized to direct pointer usage
280+
for (auto block : func->getBlocks())
281+
{
282+
for (auto inst : block->getChildren())
283+
{
284+
if (auto storeInst = as<IRStore>(inst))
285+
{
286+
auto storedValue = storeInst->getVal();
287+
auto destPtr = storeInst->getPtr();
288+
289+
// Check if we're storing a load result
290+
if (auto loadInst = as<IRLoad>(storedValue))
291+
{
292+
auto loadedPtr = loadInst->getPtr();
293+
294+
// IMPORTANT: Only optimize if destPtr is a variable (kIROp_Var)
295+
// Don't optimize stores to buffer elements or other meaningful destinations
296+
if (!as<IRVar>(destPtr))
297+
continue;
298+
299+
// Check if this load has only one use (this store)
300+
bool loadHasOnlyOneUse = true;
301+
UInt useCount = 0;
302+
for (auto use = loadInst->firstUse; use; use = use->nextUse)
303+
{
304+
useCount++;
305+
if (useCount > 1 || use->getUser() != storeInst)
306+
{
307+
loadHasOnlyOneUse = false;
308+
break;
309+
}
310+
}
311+
312+
if (loadHasOnlyOneUse)
313+
{
314+
// Replace all uses of destPtr with loadedPtr
315+
destPtr->replaceUsesWith(loadedPtr);
316+
317+
// Mark both instructions for removal
318+
toRemove.add(storeInst);
319+
toRemove.add(loadInst);
320+
321+
// Also remove the variable if it's only used by this store
322+
if (auto varInst = as<IRVar>(destPtr))
323+
{
324+
bool varHasOnlyStoreUse = true;
325+
for (auto use = varInst->firstUse; use; use = use->nextUse)
326+
{
327+
if (use->getUser() != storeInst)
328+
{
329+
varHasOnlyStoreUse = false;
330+
break;
331+
}
332+
}
333+
if (varHasOnlyStoreUse)
334+
{
335+
toRemove.add(varInst);
336+
}
337+
}
338+
339+
changed = true;
340+
}
341+
}
342+
}
343+
}
344+
}
345+
346+
// Remove marked instructions
347+
for (auto inst : toRemove)
348+
{
349+
if (inst->getParent())
350+
{
351+
inst->removeAndDeallocate();
352+
}
353+
}
354+
}
355+
182356
// Process a single function
183357
void processFunc(IRFunc* func)
184358
{
185359
HashSet<IRParam*> updatedParams;
186-
bool hasTransformedParams = false;
187360

188361
// First pass: Transform parameter types
189362
for (auto param = func->getFirstParam(); param; param = param->getNextParam())
@@ -203,13 +376,12 @@ struct TransformParamsToConstRefContext
203376
auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal);
204377
param->setFullType(constRefType);
205378

206-
hasTransformedParams = true;
207379
changed = true;
208380
updatedParams.add(param);
209381
}
210382
}
211383

212-
if (!hasTransformedParams)
384+
if (updatedParams.getCount() == 0)
213385
{
214386
return;
215387
}
@@ -219,6 +391,9 @@ struct TransformParamsToConstRefContext
219391

220392
// Third pass: Update call sites
221393
updateCallSites(func, updatedParams);
394+
395+
// Optimization pass: Remove unnecessary load+store pairs created by updateCallSites
396+
eliminateLoadStorePairs(func, updatedParams);
222397
}
223398

224399
void addFuncsToCallListInTopologicalOrder(
@@ -274,6 +449,22 @@ struct TransformParamsToConstRefContext
274449
processFunc(func);
275450
}
276451

452+
// Handle entry point functions separately - they don't get processed by processFunc
453+
// but they still need load/store optimization for parameters that come from global parameters
454+
for (auto inst = module->getModuleInst()->getFirstChild(); inst; inst = inst->getNextInst())
455+
{
456+
auto func = as<IRFunc>(inst);
457+
if (!func || !func->isDefinition())
458+
continue;
459+
460+
// Only process entry point functions that weren't already processed
461+
if (!shouldProcessFunction(func))
462+
{
463+
// For entry point functions, use a broader optimization since we don't have specific updatedParams
464+
eliminateLoadStorePairsForEntryPoint(func);
465+
}
466+
}
467+
277468
return SLANG_OK;
278469
}
279470
};
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target cuda -stage compute -entry computeMain
2+
3+
// Test for GitHub issue #8412: Avoid unnecessary copy of function parameters
4+
// This test verifies that member function calls on ParameterBlock types
5+
// do not generate unnecessary temporary copies in CUDA output.
6+
7+
struct LargeStruct
8+
{
9+
float4 data[16]; // 256 bytes - large enough to make copying expensive
10+
11+
// Non-mutating method that should not create unnecessary copies
12+
float getValue(int index)
13+
{
14+
return data[index].x + data[index].y;
15+
}
16+
17+
// Another method to test multiple member function calls
18+
float4 getVector(int index)
19+
{
20+
return data[index];
21+
}
22+
}
23+
24+
ParameterBlock<LargeStruct> pb;
25+
RWStructuredBuffer<float> outputBuffer;
26+
27+
void myFunc(LargeStruct tmp)
28+
{
29+
// These calls should NOT generate temporary copies in CUDA output
30+
float result1 = tmp.getValue(0);
31+
float4 result2 = tmp.getVector(1);
32+
33+
outputBuffer[0] = result1;
34+
outputBuffer[1] = result2.x;
35+
}
36+
37+
[numthreads(1, 1, 1)]
38+
void computeMain(uint3 id: SV_DispatchThreadID)
39+
{
40+
myFunc(pb);
41+
}
42+
43+
// CHECK: __device__ float LargeStruct_getValue
44+
// CHECK: __device__ float4 LargeStruct_getVector
45+
// CHECK: __device__ void myFunc
46+
// CHECK-NOT: LargeStruct{{.*}} {{.*}} = *
47+
// CHECK: LargeStruct_getValue
48+
// CHECK-NOT: LargeStruct{{.*}} {{.*}} = *
49+
// CHECK: LargeStruct_getVector

tests/cuda/copy-elision-this-2.slang

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ void modify(inout int data[10])
7070
// CUDA: notDirectlyUsingParam{{.*}}Array{{.*}}*{{.*}}data
7171
int notDirectlyUsingParam(int data[10], int val)
7272
{
73-
// ensure we create a temporary for the array
74-
// CUDA: FixedArray{{.*}};
73+
// do not create a temporary for the array
74+
// CUDA-NOT: FixedArray{{.*}};
7575
modify(data);
7676
return data[1] + val;
7777
}
@@ -114,7 +114,7 @@ void computeMain()
114114
// not directly using param
115115
int notDirectlyUsingParamVal = notDirectlyUsingParam(val, input[1]);
116116

117-
output[0] =
117+
output[0] =
118118
structVal == 3 &&
119119
globalParamStructVal == 3 &&
120120
nestedStructVal == 3 &&
@@ -138,4 +138,4 @@ void computeMain()
138138
//output[9] = notDirectlyUsingParamVal;
139139
}
140140

141-
//BUF: 1
141+
//BUF: 1

tests/metal/vector-get-element-ptr.slang

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ void modify(inout int v)
1414
void computeMain(int3 v : SV_DispatchThreadID)
1515
{
1616
int3 u = v;
17-
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = u{{.*}}.x;
18-
// CHECK: modify{{.*}}(&[[TEMP]])
19-
// CHECK: u{{.*}}.x = [[TEMP]];
17+
// CHECK: thread int3 [[TEMP:[a-zA-Z0-9_]+]] = int3(v{{.*}};
18+
// CHECK: modify{{.*}}(&[[TEMP]][int(0)])
2019

2120
modify(u.x);
21+
22+
// CHECK: outputBuffer{{.*}} = [[TEMP]].x + [[TEMP]].y;
2223
// BUF: 2
2324
outputBuffer[0] = u.x + u.y;
24-
}
25+
}

0 commit comments

Comments
 (0)