Skip to content

Commit 9c2024a

Browse files
jkwak-workslangbot
andauthored
Remove unnecessary Load and Store pair (#8433)
This commit removes unnecessary Load and Store pairs in IR. When the IR is like ``` let %1 = var let %2 = load(%ptr) store(%1 %2) ``` This PR will replace all uses of %1 with %ptr. And the load and store instructions will be removed. But I found that there can be cases where %2 might be still used later in other IRs. For these cases, the removal of load instruction relies on DCE. --------- Co-authored-by: slangbot <[email protected]>
1 parent 979e16a commit 9c2024a

File tree

4 files changed

+205
-16
lines changed

4 files changed

+205
-16
lines changed

source/slang/slang-ir-dce.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o
3535
bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex);
3636

3737
bool trimOptimizableTypes(IRModule* module);
38+
3839
} // namespace Slang

source/slang/slang-ir-redundancy-removal.cpp

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,202 @@ bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts)
126126
return changed;
127127
}
128128

129+
bool isAddressMutable(IRInst* inst)
130+
{
131+
auto rootType = getRootAddr(inst)->getDataType();
132+
switch (rootType->getOp())
133+
{
134+
case kIROp_ParameterBlockType:
135+
case kIROp_ConstantBufferType:
136+
case kIROp_ConstRefType:
137+
return false; // immutable
138+
139+
// We should consider StructuredBuffer as mutable by default, since the resources may alias.
140+
// There could be anotherRWStructuredBuffer pointing to the same memory location as the
141+
// structured buffer.
142+
case kIROp_StructuredBufferLoad:
143+
case kIROp_GetStructuredBufferPtr:
144+
return true; // mutable
145+
}
146+
147+
// Similarly, IRPtrTypeBase should also be considered writable always,
148+
// because there can be aliasing.
149+
150+
return true; // mutable
151+
}
152+
153+
/// Eliminate redundant temporary variable copies in load-store patterns.
154+
/// This optimization looks for patterns where a value is loaded from memory
155+
/// and immediately stored to a temporary variable, which is then only used
156+
/// in read-only contexts. In such cases, the temporary variable and the
157+
/// load-store indirection can be eliminated by using the original memory
158+
/// location directly.
159+
/// Returns true if any changes were made.
160+
static bool eliminateRedundantTemporaryCopyInFunc(IRFunc* func)
161+
{
162+
// Consider the following IR pattern:
163+
// ```
164+
// let %temp = var
165+
// let %value = load(%sourcePtr)
166+
// store(%temp, %value)
167+
// ```
168+
// We can replace "%temp" with "%sourcePtr" without the load and store indirection
169+
// if "%temp" is used only in read-only contexts.
170+
171+
bool overallChanged = false;
172+
for (bool changed = true; changed;)
173+
{
174+
changed = false;
175+
176+
HashSet<IRInst*> toRemove;
177+
178+
for (auto block : func->getBlocks())
179+
{
180+
for (auto blockInst : block->getChildren())
181+
{
182+
auto storeInst = as<IRStore>(blockInst);
183+
if (!storeInst)
184+
{
185+
// We are interested only in IRStore.
186+
continue;
187+
}
188+
189+
auto storedValue = storeInst->getVal();
190+
auto destPtr = storeInst->getPtr();
191+
192+
if (destPtr->getOp() != kIROp_Var)
193+
{
194+
// Only optimize temporary variable.
195+
// Don't optimize stores to permanent memory locations.
196+
continue;
197+
}
198+
199+
// Check if we're storing a load result
200+
auto loadInst = as<IRLoad>(storedValue);
201+
if (!loadInst)
202+
{
203+
// Skip because only IRLoad is expected for the optimization.
204+
continue;
205+
}
206+
207+
auto loadPtr = loadInst->getPtr();
208+
209+
if (isAddressMutable(loadPtr))
210+
{
211+
// If the input is mutable, we cannot optimize,
212+
// because any function calls may alter the content of the input
213+
// and we cannot replace the temporary copy with a memory pointer.
214+
continue;
215+
}
216+
217+
// Storing address-sapce for later use.
218+
AddressSpace loadAddressSpace = AddressSpace::Generic;
219+
if (auto rootPtrType = as<IRPtrTypeBase>(getRootAddr(loadPtr)->getDataType()))
220+
{
221+
loadAddressSpace = rootPtrType->getAddressSpace();
222+
}
223+
224+
// Do not optimize loads from semantic parameters because some semantics have
225+
// builtin types that are vector types but pretend to be scalar types (e.g.,
226+
// SV_DispatchThreadID is used as 'int id' but maps to 'float3
227+
// gl_GlobalInvocationID'). The legalization step must remove the load instruction
228+
// to maintain this pretense, which breaks our load/store optimization assumptions.
229+
// Skip optimization when loading from semantics to let legalization handle the load
230+
// removal.
231+
if (auto param = as<IRParam>(loadPtr))
232+
if (param->findDecoration<IRSemanticDecoration>())
233+
continue;
234+
235+
// Check all uses of the destination variable
236+
for (auto use = destPtr->firstUse; use; use = use->nextUse)
237+
{
238+
auto user = use->getUser();
239+
if (user == storeInst)
240+
{
241+
// Skip the store itself
242+
continue; // check the next use
243+
}
244+
245+
if (as<IRStore>(use->getUser()))
246+
{
247+
// We cannot optimize when the variable is reused
248+
// with another store.
249+
goto unsafeToOptimize;
250+
}
251+
252+
if (as<IRLoad>(user))
253+
{
254+
// Allow loads because IRLoad is read-only operation
255+
continue; // Check the next use
256+
}
257+
258+
// For function calls, check if the pointer is treated as immutable.
259+
if (auto call = as<IRCall>(user))
260+
{
261+
auto callee = call->getCallee();
262+
auto funcInst = as<IRFunc>(callee);
263+
if (!funcInst)
264+
goto unsafeToOptimize;
265+
266+
UIndex argIndex = (UIndex)(use - call->getArgs());
267+
SLANG_ASSERT(argIndex < call->getArgCount());
268+
SLANG_ASSERT(call->getArg(argIndex) == destPtr);
269+
270+
IRParam* param = funcInst->getFirstParam();
271+
for (UIndex i = 0; i < argIndex; i++)
272+
{
273+
if (param)
274+
param = param->getNextParam();
275+
}
276+
if (nullptr == param)
277+
goto unsafeToOptimize; // IRFunc might be incomplete yet
278+
279+
if (auto paramPtrType = as<IRConstRefType>(param->getFullType()))
280+
{
281+
if (paramPtrType->getAddressSpace() != loadAddressSpace)
282+
goto unsafeToOptimize; // incompatible address space
283+
284+
continue; // safe so far and check the next use
285+
}
286+
goto unsafeToOptimize; // must be const-ref
287+
}
288+
289+
// TODO: there might be more cases that is safe to optimize
290+
// We need to add more cases here as needed.
291+
292+
// If we get here, the pointer is used with an unexpected IR.
293+
goto unsafeToOptimize;
294+
}
295+
296+
// If we get here, all uses are safe to optimize.
297+
298+
// Replace all uses of destPtr with loadedPtr
299+
destPtr->replaceUsesWith(loadPtr);
300+
301+
// Mark instructions for removal
302+
toRemove.add(storeInst);
303+
toRemove.add(destPtr);
304+
305+
// Note: loadInst might be still in use.
306+
// We need to rely on DCE to delete it if unused.
307+
308+
changed = true;
309+
overallChanged = true;
310+
311+
unsafeToOptimize:;
312+
}
313+
}
314+
315+
// Remove marked instructions
316+
for (auto instToRemove : toRemove)
317+
{
318+
instToRemove->removeAndDeallocate();
319+
}
320+
}
321+
322+
return overallChanged;
323+
}
324+
129325
bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts)
130326
{
131327
auto root = func->getFirstBlock();
@@ -163,6 +359,7 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariant
163359
if (auto normalFunc = as<IRFunc>(func))
164360
{
165361
result |= eliminateRedundantLoadStore(normalFunc);
362+
result |= eliminateRedundantTemporaryCopyInFunc(normalFunc);
166363
}
167364
return result;
168365
}

source/slang/slang-ir-transform-params-to-constref.cpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct TransformParamsToConstRefContext
6060
case kIROp_FieldExtract:
6161
{
6262
// Transform the IRFieldExtract into a IRFieldAddress
63-
auto fieldExtract = as<IRFieldExtract>(use->getUser());
63+
auto fieldExtract = as<IRFieldExtract>(user);
6464
builder.setInsertBefore(fieldExtract);
6565
auto fieldAddr = builder.emitFieldAddress(
6666
fieldExtract->getBase(),
@@ -73,8 +73,7 @@ struct TransformParamsToConstRefContext
7373
case kIROp_GetElement:
7474
{
7575
// Transform the IRGetElement into a IRGetElementPtr
76-
auto getElement = as<IRGetElement>(use->getUser());
77-
76+
auto getElement = as<IRGetElement>(user);
7877
builder.setInsertBefore(getElement);
7978
auto elemAddr = builder.emitElementAddress(
8079
getElement->getBase(),
@@ -111,14 +110,8 @@ struct TransformParamsToConstRefContext
111110
List<IRInst*> newArgs;
112111

113112
// Transform arguments to match the updated-parameter
114-
IRParam* param = func->getFirstParam();
115113
UInt i = 0;
116-
auto iterate = [&]()
117-
{
118-
param = param->getNextParam();
119-
i++;
120-
};
121-
for (; param; iterate())
114+
for (IRParam* param = func->getFirstParam(); param; param = param->getNextParam(), i++)
122115
{
123116
auto arg = call->getArg(i);
124117
if (!updatedParams.contains(param))
@@ -183,7 +176,6 @@ struct TransformParamsToConstRefContext
183176
void processFunc(IRFunc* func)
184177
{
185178
HashSet<IRParam*> updatedParams;
186-
bool hasTransformedParams = false;
187179

188180
// First pass: Transform parameter types
189181
for (auto param = func->getFirstParam(); param; param = param->getNextParam())
@@ -203,13 +195,12 @@ struct TransformParamsToConstRefContext
203195
auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal);
204196
param->setFullType(constRefType);
205197

206-
hasTransformedParams = true;
207198
changed = true;
208199
updatedParams.add(param);
209200
}
210201
}
211202

212-
if (!hasTransformedParams)
203+
if (updatedParams.getCount() == 0)
213204
{
214205
return;
215206
}

tests/metal/vector-get-element-ptr.slang

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ void modify(inout int v)
1414
void computeMain(int3 v : SV_DispatchThreadID)
1515
{
1616
int3 u = v;
17-
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = u{{.*}}.x;
17+
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = [[OUT:[a-zA-Z0-9_]+]].x;
1818
// CHECK: modify{{.*}}(&[[TEMP]])
19-
// CHECK: u{{.*}}.x = [[TEMP]];
19+
// CHECK: [[OUT]].x = [[TEMP]];
2020

2121
modify(u.x);
2222
// BUF: 2
2323
outputBuffer[0] = u.x + u.y;
24-
}
24+
}

0 commit comments

Comments
 (0)