Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/slang/slang-ir-dce.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o
bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex);

bool trimOptimizableTypes(IRModule* module);

} // namespace Slang
197 changes: 197 additions & 0 deletions source/slang/slang-ir-redundancy-removal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,202 @@ bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts)
return changed;
}

/// Conservatively decide whether the memory behind an address may be written
/// while the current function is running.
///
/// The decision is made from the data type of the root of the address chain
/// (as returned by `getRootAddr`).
///
/// @param inst An address-producing instruction.
/// @return `false` only when the root address has a type that is known to be
///         read-only (parameter block, constant buffer, or const-ref);
///         `true` in every other case.
bool isAddressMutable(IRInst* inst)
{
    auto rootAddr = getRootAddr(inst);
    auto rootType = rootAddr ? rootAddr->getDataType() : nullptr;
    if (!rootType)
    {
        // Without type information we cannot prove immutability, so take the
        // conservative answer.
        return true; // mutable
    }

    switch (rootType->getOp())
    {
    case kIROp_ParameterBlockType:
    case kIROp_ConstantBufferType:
    case kIROp_ConstRefType:
        return false; // immutable

    default:
        break;
    }

    // Everything else is treated as mutable:
    //
    // - Structured buffers are considered mutable by default, since resources
    //   may alias. There could be another RWStructuredBuffer pointing to the
    //   same memory location as a read-only structured buffer.
    //   (Note: `kIROp_StructuredBufferLoad` and `kIROp_GetStructuredBufferPtr`
    //   are instruction opcodes, so they can never match the opcode of a data
    //   type and must not be listed as cases in the switch above.)
    //
    // - Similarly, IRPtrTypeBase should also be considered writable always,
    //   because there can be aliasing.
    return true; // mutable
}

/// Eliminate redundant temporary variable copies in load-store patterns.
///
/// This optimization looks for patterns where a value is loaded from memory
/// and immediately stored to a temporary variable, which is then only used
/// in read-only contexts. In such cases, the temporary variable and the
/// load-store indirection can be eliminated by using the original memory
/// location directly.
///
/// A store is rewritten only when all of the following hold:
///  - the destination is a `kIROp_Var`,
///  - the stored value is an `IRLoad`,
///  - `isAddressMutable` reports the load source as immutable,
///  - the load source is not an `IRParam` carrying an `IRSemanticDecoration`,
///  - every other use of the variable is either an `IRLoad`, or an argument of
///    an `IRCall` whose callee is an `IRFunc` and whose matching parameter is
///    an `IRConstRefType` in the same address space as the load source.
///
/// The scan is repeated until a full pass over the function makes no change.
///
/// @param func The function whose blocks are scanned and rewritten in place.
/// @return true if any changes were made.
static bool eliminateRedundantTemporaryCopyInFunc(IRFunc* func)
{
// Consider the following IR pattern:
// ```
// let %temp = var
// let %value = load(%sourcePtr)
// store(%temp, %value)
// ```
// We can replace "%temp" with "%sourcePtr" without the load and store indirection
// if "%temp" is used only in read-only contexts.

// `overallChanged` is the function result; `changed` only drives the
// fixed-point loop below.
bool overallChanged = false;
for (bool changed = true; changed;)
{
changed = false;

// Removal is deferred until the scan of all blocks is finished, so that
// no instruction is deallocated while the child lists are being walked.
HashSet<IRInst*> toRemove;

for (auto block : func->getBlocks())
{
for (auto blockInst : block->getChildren())
{
auto storeInst = as<IRStore>(blockInst);
if (!storeInst)
{
// We are interested only in IRStore.
continue;
}

auto storedValue = storeInst->getVal();
auto destPtr = storeInst->getPtr();

if (destPtr->getOp() != kIROp_Var)
{
// Only optimize temporary variable.
// Don't optimize stores to permanent memory locations.
continue;
}

// Check if we're storing a load result
auto loadInst = as<IRLoad>(storedValue);
if (!loadInst)
{
// Skip because only IRLoad is expected for the optimization.
continue;
}

auto loadPtr = loadInst->getPtr();

if (isAddressMutable(loadPtr))
{
// If the input is mutable, we cannot optimize,
// because any function calls may alter the content of the input
// and we cannot replace the temporary copy with a memory pointer.
continue;
}

// Store the address space of the load source for later use.
// It stays `AddressSpace::Generic` when the root address is not typed as
// an IRPtrTypeBase.
AddressSpace loadAddressSpace = AddressSpace::Generic;
if (auto rootPtrType = as<IRPtrTypeBase>(getRootAddr(loadPtr)->getDataType()))
{
loadAddressSpace = rootPtrType->getAddressSpace();
}

// Do not optimize loads from semantic parameters because some semantics have
// builtin types that are vector types but pretend to be scalar types (e.g.,
// SV_DispatchThreadID is used as 'int id' but maps to 'float3
// gl_GlobalInvocationID'). The legalization step must remove the load instruction
// to maintain this pretense, which breaks our load/store optimization assumptions.
// Skip optimization when loading from semantics to let legalization handle the load
// removal.
// NOTE(review): this inspects `loadPtr` itself rather than
// `getRootAddr(loadPtr)`; confirm whether addresses derived from a
// semantic parameter should be skipped as well.
if (auto param = as<IRParam>(loadPtr))
if (param->findDecoration<IRSemanticDecoration>())
continue;

// Check all uses of the destination variable.
// Any use that is not recognized as read-only jumps to `unsafeToOptimize`,
// which abandons this store without modifying anything.
for (auto use = destPtr->firstUse; use; use = use->nextUse)
{
auto user = use->getUser();
if (user == storeInst)
{
// Skip the store itself
continue; // check the next use
}

if (as<IRStore>(use->getUser()))
{
// We cannot optimize when the variable is reused
// with another store.
goto unsafeToOptimize;
}

if (as<IRLoad>(user))
{
// Allow loads because IRLoad is read-only operation
continue; // Check the next use
}

// For function calls, check if the pointer is treated as immutable.
if (auto call = as<IRCall>(user))
{
// Only calls whose callee is directly an IRFunc can be inspected.
auto callee = call->getCallee();
auto funcInst = as<IRFunc>(callee);
if (!funcInst)
goto unsafeToOptimize;

// Recover the argument position from the offset of this use
// relative to the first argument of the call.
UIndex argIndex = (UIndex)(use - call->getArgs());
SLANG_ASSERT(argIndex < call->getArgCount());
SLANG_ASSERT(call->getArg(argIndex) == destPtr);

// Walk the callee parameter list to the parameter at `argIndex`.
// `param` ends up null when the callee has fewer parameters.
IRParam* param = funcInst->getFirstParam();
for (UIndex i = 0; i < argIndex; i++)
{
if (param)
param = param->getNextParam();
}
if (nullptr == param)
goto unsafeToOptimize; // IRFunc might be incomplete yet

if (auto paramPtrType = as<IRConstRefType>(param->getFullType()))
{
if (paramPtrType->getAddressSpace() != loadAddressSpace)
goto unsafeToOptimize; // incompatible address space
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check isn't necessary.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed on MacOS where "constant*" point was not compatible with "thread*" pointer.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is tests/compute/cbuffer-legalize.slang, and the emitted Metal shader is following, as a quick reference.

#include <metal_stdlib>
#include <metal_math>
#include <metal_texture>
using namespace metal;
struct P_0
{
    uint4 c_0;
};

float4 test_0(const P_0 thread* p_0, texture2d<float, access::sample> p_t_0, sampler p_s_0)
{
    return ((p_t_0).sample((p_s_0), (float2(0.0) ), level((0.0)))) + float4(p_0->c_0);
}

struct SLANG_ParameterGroup_C_0
{
    P_0 p_1;
};

struct KernelContext_0
{
    SLANG_ParameterGroup_C_0 constant* C_0;
    texture2d<float, access::sample> C_p_t_0;
    sampler C_p_s_0;
    float device* outputBuffer_0;
};

[[kernel]] void computeMain(uint3 dispatchThreadID_0 [[thread_position_in_grid]], SLANG_ParameterGroup_C_0 constant* C_1 [[buffer(0)]], texture2d<float, access::sample> C_p_t_1 [[texture(0)]], sampler C_p_s_1 [[sampler(0)]], float device* outputBuffer_1 [[buffer(1)]])
{
    KernelContext_0 kernelContext_0;
    (&kernelContext_0)->C_0 = C_1;
    (&kernelContext_0)->C_p_t_0 = C_p_t_1;
    (&kernelContext_0)->C_p_s_0 = C_p_s_1;
    (&kernelContext_0)->outputBuffer_0 = outputBuffer_1;
    P_0 _S1 = (&kernelContext_0)->C_0->p_1;
    float4 _S2 = test_0(&_S1, (&kernelContext_0)->C_p_t_0, (&kernelContext_0)->C_p_s_0);
    *(outputBuffer_1+int(0)) = _S2.x;
    *((&kernelContext_0)->outputBuffer_0+int(1)) = _S2.y;
    *((&kernelContext_0)->outputBuffer_0+int(2)) = _S2.z;
    *((&kernelContext_0)->outputBuffer_0+int(3)) = _S2.w;
    return;
}

The trouble is when test_0() was called as following,

    float4 _S2 = test_0(&((&kernelContext_0)->C_0->p_1), (&kernelContext_0)->C_p_t_0, (&kernelContext_0)->C_p_s_0);

The constant* was not compatible with thread*.


continue; // safe so far and check the next use
}
goto unsafeToOptimize; // must be const-ref
}

// TODO: there might be more cases that are safe to optimize.
// We need to add more cases here as needed.

// If we get here, the pointer is used with an unexpected IR.
goto unsafeToOptimize;
}

// If we get here, all uses are safe to optimize.

// Replace all uses of destPtr with loadPtr
destPtr->replaceUsesWith(loadPtr);

// Mark instructions for removal
toRemove.add(storeInst);
toRemove.add(destPtr);

// Note: loadInst might be still in use.
// We need to rely on DCE to delete it if unused.

changed = true;
overallChanged = true;

// Bail-out target: the empty statement ends the loop body, so jumping here
// simply moves on to the next instruction in the block.
unsafeToOptimize:;
}
}

// Remove marked instructions
for (auto instToRemove : toRemove)
{
instToRemove->removeAndDeallocate();
}
}

return overallChanged;
}

bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts)
{
auto root = func->getFirstBlock();
Expand Down Expand Up @@ -163,6 +359,7 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariant
if (auto normalFunc = as<IRFunc>(func))
{
result |= eliminateRedundantLoadStore(normalFunc);
result |= eliminateRedundantTemporaryCopyInFunc(normalFunc);
}
return result;
}
Expand Down
17 changes: 4 additions & 13 deletions source/slang/slang-ir-transform-params-to-constref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct TransformParamsToConstRefContext
case kIROp_FieldExtract:
{
// Transform the IRFieldExtract into a IRFieldAddress
auto fieldExtract = as<IRFieldExtract>(use->getUser());
auto fieldExtract = as<IRFieldExtract>(user);
builder.setInsertBefore(fieldExtract);
auto fieldAddr = builder.emitFieldAddress(
fieldExtract->getBase(),
Expand All @@ -73,8 +73,7 @@ struct TransformParamsToConstRefContext
case kIROp_GetElement:
{
// Transform the IRGetElement into a IRGetElementPtr
auto getElement = as<IRGetElement>(use->getUser());

auto getElement = as<IRGetElement>(user);
builder.setInsertBefore(getElement);
auto elemAddr = builder.emitElementAddress(
getElement->getBase(),
Expand Down Expand Up @@ -111,14 +110,8 @@ struct TransformParamsToConstRefContext
List<IRInst*> newArgs;

// Transform arguments to match the updated-parameter
IRParam* param = func->getFirstParam();
UInt i = 0;
auto iterate = [&]()
{
param = param->getNextParam();
i++;
};
for (; param; iterate())
for (IRParam* param = func->getFirstParam(); param; param = param->getNextParam(), i++)
{
auto arg = call->getArg(i);
if (!updatedParams.contains(param))
Expand Down Expand Up @@ -183,7 +176,6 @@ struct TransformParamsToConstRefContext
void processFunc(IRFunc* func)
{
HashSet<IRParam*> updatedParams;
bool hasTransformedParams = false;

// First pass: Transform parameter types
for (auto param = func->getFirstParam(); param; param = param->getNextParam())
Expand All @@ -203,13 +195,12 @@ struct TransformParamsToConstRefContext
auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal);
param->setFullType(constRefType);

hasTransformedParams = true;
changed = true;
updatedParams.add(param);
}
}

if (!hasTransformedParams)
if (updatedParams.getCount() == 0)
{
return;
}
Expand Down
6 changes: 3 additions & 3 deletions tests/metal/vector-get-element-ptr.slang
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ void modify(inout int v)
void computeMain(int3 v : SV_DispatchThreadID)
{
int3 u = v;
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = u{{.*}}.x;
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = [[OUT:[a-zA-Z0-9_]+]].x;
// CHECK: modify{{.*}}(&[[TEMP]])
// CHECK: u{{.*}}.x = [[TEMP]];
// CHECK: [[OUT]].x = [[TEMP]];

modify(u.x);
// BUF: 2
outputBuffer[0] = u.x + u.y;
}
}