diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h index 869ebc36ae..4df491881a 100644 --- a/source/slang/slang-ir-dce.h +++ b/source/slang/slang-ir-dce.h @@ -35,4 +35,5 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex); bool trimOptimizableTypes(IRModule* module); + } // namespace Slang diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index c52ab7bcae..94bb6b67c6 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -126,6 +126,202 @@ bool removeRedundancy(IRModule* module, bool hoistLoopInvariantInsts) return changed; } +bool isAddressMutable(IRInst* inst) +{ + auto rootType = getRootAddr(inst)->getDataType(); + switch (rootType->getOp()) + { + case kIROp_ParameterBlockType: + case kIROp_ConstantBufferType: + case kIROp_ConstRefType: + return false; // immutable + + // We should consider StructuredBuffer as mutable by default, since the resources may alias. + // There could be anotherRWStructuredBuffer pointing to the same memory location as the + // structured buffer. + case kIROp_StructuredBufferLoad: + case kIROp_GetStructuredBufferPtr: + return true; // mutable + } + + // Similarly, IRPtrTypeBase should also be considered writable always, + // because there can be aliasing. + + return true; // mutable +} + +/// Eliminate redundant temporary variable copies in load-store patterns. +/// This optimization looks for patterns where a value is loaded from memory +/// and immediately stored to a temporary variable, which is then only used +/// in read-only contexts. In such cases, the temporary variable and the +/// load-store indirection can be eliminated by using the original memory +/// location directly. +/// Returns true if any changes were made. +static bool eliminateRedundantTemporaryCopyInFunc(IRFunc* func) +{ + // Consider the following IR pattern: + // ``` + // let %temp = var + // let %value = load(%sourcePtr) + // store(%temp, %value) + // ``` + // We can replace "%temp" with "%sourcePtr" without the load and store indirection + // if "%temp" is used only in read-only contexts. + + bool overallChanged = false; + for (bool changed = true; changed;) + { + changed = false; + + HashSet toRemove; + + for (auto block : func->getBlocks()) + { + for (auto blockInst : block->getChildren()) + { + auto storeInst = as(blockInst); + if (!storeInst) + { + // We are interested only in IRStore. + continue; + } + + auto storedValue = storeInst->getVal(); + auto destPtr = storeInst->getPtr(); + + if (destPtr->getOp() != kIROp_Var) + { + // Only optimize temporary variable. + // Don't optimize stores to permanent memory locations. + continue; + } + + // Check if we're storing a load result + auto loadInst = as(storedValue); + if (!loadInst) + { + // Skip because only IRLoad is expected for the optimization. + continue; + } + + auto loadPtr = loadInst->getPtr(); + + if (isAddressMutable(loadPtr)) + { + // If the input is mutable, we cannot optimize, + // because any function calls may alter the content of the input + // and we cannot replace the temporary copy with a memory pointer. + continue; + } + + // Storing address-sapce for later use. + AddressSpace loadAddressSpace = AddressSpace::Generic; + if (auto rootPtrType = as(getRootAddr(loadPtr)->getDataType())) + { + loadAddressSpace = rootPtrType->getAddressSpace(); + } + + // Do not optimize loads from semantic parameters because some semantics have + // builtin types that are vector types but pretend to be scalar types (e.g., + // SV_DispatchThreadID is used as 'int id' but maps to 'float3 + // gl_GlobalInvocationID'). The legalization step must remove the load instruction + // to maintain this pretense, which breaks our load/store optimization assumptions. + // Skip optimization when loading from semantics to let legalization handle the load + // removal. + if (auto param = as(loadPtr)) + if (param->findDecoration()) + continue; + + // Check all uses of the destination variable + for (auto use = destPtr->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (user == storeInst) + { + // Skip the store itself + continue; // check the next use + } + + if (as(use->getUser())) + { + // We cannot optimize when the variable is reused + // with another store. + goto unsafeToOptimize; + } + + if (as(user)) + { + // Allow loads because IRLoad is read-only operation + continue; // Check the next use + } + + // For function calls, check if the pointer is treated as immutable. + if (auto call = as(user)) + { + auto callee = call->getCallee(); + auto funcInst = as(callee); + if (!funcInst) + goto unsafeToOptimize; + + UIndex argIndex = (UIndex)(use - call->getArgs()); + SLANG_ASSERT(argIndex < call->getArgCount()); + SLANG_ASSERT(call->getArg(argIndex) == destPtr); + + IRParam* param = funcInst->getFirstParam(); + for (UIndex i = 0; i < argIndex; i++) + { + if (param) + param = param->getNextParam(); + } + if (nullptr == param) + goto unsafeToOptimize; // IRFunc might be incomplete yet + + if (auto paramPtrType = as(param->getFullType())) + { + if (paramPtrType->getAddressSpace() != loadAddressSpace) + goto unsafeToOptimize; // incompatible address space + + continue; // safe so far and check the next use + } + goto unsafeToOptimize; // must be const-ref + } + + // TODO: there might be more cases that is safe to optimize + // We need to add more cases here as needed. + + // If we get here, the pointer is used with an unexpected IR. + goto unsafeToOptimize; + } + + // If we get here, all uses are safe to optimize. + + // Replace all uses of destPtr with loadedPtr + destPtr->replaceUsesWith(loadPtr); + + // Mark instructions for removal + toRemove.add(storeInst); + toRemove.add(destPtr); + + // Note: loadInst might be still in use. + // We need to rely on DCE to delete it if unused. + + changed = true; + overallChanged = true; + + unsafeToOptimize:; + } + } + + // Remove marked instructions + for (auto instToRemove : toRemove) + { + instToRemove->removeAndDeallocate(); + } + } + + return overallChanged; +} + bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariantInsts) { auto root = func->getFirstBlock(); @@ -163,6 +359,7 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func, bool hoistLoopInvariant if (auto normalFunc = as(func)) { result |= eliminateRedundantLoadStore(normalFunc); + result |= eliminateRedundantTemporaryCopyInFunc(normalFunc); } return result; } diff --git a/source/slang/slang-ir-transform-params-to-constref.cpp b/source/slang/slang-ir-transform-params-to-constref.cpp index 8f4bcd0379..9328a1de11 100644 --- a/source/slang/slang-ir-transform-params-to-constref.cpp +++ b/source/slang/slang-ir-transform-params-to-constref.cpp @@ -60,7 +60,7 @@ struct TransformParamsToConstRefContext case kIROp_FieldExtract: { // Transform the IRFieldExtract into a IRFieldAddress - auto fieldExtract = as(use->getUser()); + auto fieldExtract = as(user); builder.setInsertBefore(fieldExtract); auto fieldAddr = builder.emitFieldAddress( fieldExtract->getBase(), @@ -73,8 +73,7 @@ struct TransformParamsToConstRefContext case kIROp_GetElement: { // Transform the IRGetElement into a IRGetElementPtr - auto getElement = as(use->getUser()); - + auto getElement = as(user); builder.setInsertBefore(getElement); auto elemAddr = builder.emitElementAddress( getElement->getBase(), @@ -111,14 +110,8 @@ struct TransformParamsToConstRefContext List newArgs; // Transform arguments to match the updated-parameter - IRParam* param = func->getFirstParam(); UInt i = 0; - auto iterate = [&]() - { - param = param->getNextParam(); - i++; - }; - for (; param; iterate()) + for (IRParam* param = func->getFirstParam(); param; param = param->getNextParam(), i++) { auto arg = call->getArg(i); if (!updatedParams.contains(param)) @@ -183,7 +176,6 @@ struct TransformParamsToConstRefContext void processFunc(IRFunc* func) { HashSet updatedParams; - bool hasTransformedParams = false; // First pass: Transform parameter types for (auto param = func->getFirstParam(); param; param = param->getNextParam()) @@ -203,13 +195,12 @@ struct TransformParamsToConstRefContext auto constRefType = builder.getConstRefType(paramType, AddressSpace::ThreadLocal); param->setFullType(constRefType); - hasTransformedParams = true; changed = true; updatedParams.add(param); } } - if (!hasTransformedParams) + if (updatedParams.getCount() == 0) { return; } diff --git a/tests/metal/vector-get-element-ptr.slang b/tests/metal/vector-get-element-ptr.slang index af2acabbce..1c616b37e6 100644 --- a/tests/metal/vector-get-element-ptr.slang +++ b/tests/metal/vector-get-element-ptr.slang @@ -14,11 +14,11 @@ void modify(inout int v) void computeMain(int3 v : SV_DispatchThreadID) { int3 u = v; - // CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = u{{.*}}.x; + // CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = [[OUT:[a-zA-Z0-9_]+]].x; // CHECK: modify{{.*}}(&[[TEMP]]) - // CHECK: u{{.*}}.x = [[TEMP]]; + // CHECK: [[OUT]].x = [[TEMP]]; modify(u.x); // BUF: 2 outputBuffer[0] = u.x + u.y; -} \ No newline at end of file +}