Skip to content

Commit 4ead542

Browse files
committed
Refactor based on the feedback
1 parent 244e17e commit 4ead542

File tree

4 files changed

+160
-142
lines changed

4 files changed

+160
-142
lines changed

source/slang/slang-ir-dce.cpp

Lines changed: 149 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -512,182 +512,199 @@ bool trimOptimizableTypes(IRModule* module)
512512
return changed;
513513
}
514514

515-
static bool eliminateLoadStorePairsInFunc(IRFunc* func, HashSet<IRFunc*>& processedFuncs)
515+
static bool eliminateRedundantTemporaryCopyInFunc(IRFunc* func)
516516
{
517-
// Avoid infinite recursion by tracking processed functions
518-
if (processedFuncs.contains(func))
519-
return false;
520-
521-
processedFuncs.add(func);
517+
// Consider the following IR pattern:
518+
// ```
519+
// let %temp = var
520+
// let %value = load(%sourcePtr)
521+
// store(%temp, %value)
522+
// ```
523+
// We can replace "%temp" with "%sourcePtr" without the load and store indirection
524+
// if "%temp" is used only in read-only contexts.
525+
526+
bool overallChanged = false;
527+
for (bool changed = true; changed; )
528+
{
529+
changed = false;
522530

523-
bool changed = false;
524-
HashSet<IRInst*> toRemove;
531+
HashSet<IRInst*> toRemove;
525532

526-
for (auto block : func->getBlocks())
527-
{
528-
for (auto blockInst : block->getChildren())
533+
for (auto block : func->getBlocks())
529534
{
530-
// First, recursively process any function calls in this block
531-
if (auto call = as<IRCall>(blockInst))
535+
for (auto blockInst : block->getChildren())
532536
{
533-
auto callee = call->getCallee();
534-
if (auto calleeFunc = as<IRFunc>(callee))
535-
{
536-
changed |= eliminateLoadStorePairsInFunc(calleeFunc, processedFuncs);
537-
}
538-
}
539-
540-
auto storeInst = as<IRStore>(blockInst);
541-
if (!storeInst)
542-
continue;
537+
auto storeInst = as<IRStore>(blockInst);
538+
if (!storeInst)
539+
continue;
543540

544-
auto storedValue = storeInst->getVal();
545-
auto destPtr = storeInst->getPtr();
541+
auto storedValue = storeInst->getVal();
542+
auto destPtr = storeInst->getPtr();
546543

547-
// Only optimize if destPtr is a variable (kIROp_Var).
548-
// Don't optimize stores to buffer elements or other meaningful destinations.
549-
if (!as<IRVar>(destPtr))
550-
continue;
551-
552-
// Check if we're storing a load result
553-
auto loadInst = as<IRLoad>(storedValue);
554-
if (!loadInst)
555-
continue;
544+
// Only optimize temporary variables.
545+
// Don't optimize stores to permanent memory locations.
546+
if (destPtr->getOp() != kIROp_Var)
547+
continue;
556548

557-
// Do not optimize the primitive types because the legalization step may assume the
558-
// existence of Load instructions when the entry parameters are replaced with the
559-
// builtin variables.
560-
auto loadInstType = loadInst->getDataType();
561-
switch (loadInstType->getOp())
562-
{
563-
case kIROp_Int8Type:
564-
case kIROp_Int16Type:
565-
case kIROp_IntType:
566-
case kIROp_Int64Type:
567-
case kIROp_UInt8Type:
568-
case kIROp_UInt16Type:
569-
case kIROp_UIntType:
570-
case kIROp_UInt64Type:
571-
case kIROp_IntPtrType:
572-
case kIROp_UIntPtrType:
573-
case kIROp_FloatType:
574-
case kIROp_DoubleType:
575-
case kIROp_HalfType:
576-
case kIROp_BoolType:
577-
continue;
578-
}
549+
// Check if we're storing a load result
550+
auto loadInst = as<IRLoad>(storedValue);
551+
if (!loadInst)
552+
continue;
579553

580-
auto loadedPtr = loadInst->getPtr();
554+
auto loadedPtr = loadInst->getPtr();
581555

582-
// Optimize only when the type is immutable (ConstRef).
583-
auto loadedPtrType = loadedPtr->getDataType();
584-
if (!as<IRConstRefType>(loadedPtrType) && !as<IRParameterBlockType>(loadedPtrType) &&
585-
!as<IRConstantBufferType>(loadedPtrType))
586-
continue;
556+
// Do not optimize loads from semantic parameters because some semantics have
557+
// builtin types that are vector types but pretend to be scalar types (e.g.,
558+
// SV_DispatchThreadID is used as 'int id' but maps to 'uvec3 gl_GlobalInvocationID').
559+
// The legalization step must remove the load instruction to maintain this pretense,
560+
// which breaks our load/store optimization assumptions. Skip optimization when
561+
// loading from semantics to let legalization handle the load removal.
562+
if (auto param = as<IRParam>(loadedPtr))
563+
if (param->findDecoration<IRSemanticDecoration>())
564+
continue;
587565

588-
// We cannot optimize if the variable is passed to a function call that treats the
589-
// parameter as mutable.
590-
bool isSafeToOptimize = true;
566+
// We cannot optimize if the variable is passed to a function call that treats the
567+
// parameter as mutable, or if there are multiple stores to the variable.
591568

592-
// Check if variable is only used in function calls with ConstRef parameters
593-
for (auto use = destPtr->firstUse; use; use = use->nextUse)
594-
{
595-
auto user = use->getUser();
596-
if (user == storeInst)
597-
continue; // Skip the store itself
569+
// If the pointer is re-used with IRStore more than once,
570+
// handling the case will become more complex.
571+
// TODO: we may want to handle the complex cases as well later.
572+
UInt storeCount = 0;
598573

599-
auto call = as<IRCall>(user);
600-
if (!call)
574+
for (auto use = destPtr->firstUse; use; use = use->nextUse)
601575
{
602-
isSafeToOptimize = false;
603-
break;
576+
if (as<IRStore>(use->getUser()))
577+
{
578+
storeCount++;
579+
if (storeCount > 1)
580+
goto unsafeToOptimize;
581+
}
604582
}
605583

606-
auto callee = call->getCallee();
607-
auto funcInst = as<IRFunc>(callee);
608-
if (!funcInst)
584+
// Check all uses of the destination variable
585+
for (auto use = destPtr->firstUse; use; use = use->nextUse)
609586
{
610-
isSafeToOptimize = false;
611-
break;
612-
}
587+
auto user = use->getUser();
588+
if (user == storeInst)
589+
continue; // Skip the store itself
613590

614-
// Find which argument position this variable is used in
615-
UInt argIndex = 0;
616-
bool foundArg = false;
617-
for (UInt i = 0; i < call->getArgCount(); i++)
618-
{
619-
if (call->getArg(i) == destPtr)
591+
// Allow loads - they are safe since we're not changing the data
592+
if (as<IRLoad>(user))
593+
continue; // Check the next use
594+
595+
// For function calls, check if the parameter is ConstRef
596+
if (auto call = as<IRCall>(user))
620597
{
621-
argIndex = i;
622-
foundArg = true;
623-
break;
598+
auto callee = call->getCallee();
599+
auto funcInst = as<IRFunc>(callee);
600+
if (!funcInst)
601+
goto unsafeToOptimize;
602+
603+
// Build parameter list once for efficient indexing
604+
List<IRParam*> params;
605+
for (auto param = funcInst->getFirstParam(); param;
606+
param = param->getNextParam())
607+
{
608+
params.add(param);
609+
}
610+
611+
// Check ALL argument positions where this variable is used
612+
for (UInt i = 0; i < call->getArgCount(); i++)
613+
{
614+
if (call->getArg(i) == destPtr)
615+
{
616+
if (i >= (UInt)params.getCount())
617+
goto unsafeToOptimize;
618+
619+
// Check if this parameter position is ConstRef
620+
auto param = params[i];
621+
if (!as<IRConstRefType>(param->getDataType()))
622+
goto unsafeToOptimize;
623+
}
624+
}
625+
continue; // Safe so far and check the next use
624626
}
625-
}
626627

627-
if (!foundArg)
628-
{
629-
isSafeToOptimize = false;
630-
break;
631-
}
628+
// TODO: there might be more cases that are safe to optimize.
629+
// We need to add more cases here as needed.
632630

633-
// Check if the corresponding parameter is ConstRef
634-
auto param = funcInst->getFirstParam();
635-
for (UInt i = 0; i < argIndex && param; i++)
636-
{
637-
param = param->getNextParam();
631+
goto unsafeToOptimize;
638632
}
639-
if (!param || !as<IRConstRefType>(param->getDataType()))
640-
{
641-
// Unsafe, because the parameter might be used as mutable.
642-
isSafeToOptimize = false;
643-
break;
644-
}
645-
}
646633

647-
if (!isSafeToOptimize)
648-
continue;
634+
// If we get here, all uses are safe to optimize
635+
// safeToOptimize:
649636

650-
// Replace all uses of destPtr with loadedPtr
651-
destPtr->replaceUsesWith(loadedPtr);
637+
// Replace all uses of destPtr with loadedPtr
638+
destPtr->replaceUsesWith(loadedPtr);
652639

653-
// Mark both instructions for removal
654-
toRemove.add(storeInst);
655-
toRemove.add(destPtr);
640+
// Mark instructions for removal
641+
toRemove.add(storeInst);
642+
toRemove.add(destPtr);
656643

657-
// Note: loadInst might be still in use.
658-
// We need to rely on DCE to delete it if unused.
644+
// Note: loadInst might be still in use.
645+
// We need to rely on DCE to delete it if unused.
659646

660-
changed = true;
647+
changed = true;
648+
overallChanged = true;
649+
650+
unsafeToOptimize:;
651+
}
661652
}
662-
}
663653

664-
// Remove marked instructions
665-
for (auto instToRemove : toRemove)
666-
{
667-
instToRemove->removeAndDeallocate();
654+
// Remove marked instructions
655+
for (auto instToRemove : toRemove)
656+
{
657+
instToRemove->removeAndDeallocate();
658+
}
668659
}
669660

670-
return changed;
661+
return overallChanged;
671662
}
672663

673-
bool eliminateLoadStorePairs(IRModule* module)
664+
bool eliminateRedundantTemporaryCopy(IRModule* module)
674665
{
675-
bool changed = false;
676-
HashSet<IRFunc*> processedFuncs;
666+
bool overallChanged = false;
667+
668+
// Populate work list with all functions and any functions they call
669+
HashSet<IRFunc*> workListSet;
670+
List<IRFunc*> workList;
677671

678-
// Look for patterns: load(ptr) followed by store(var, load_result)
679-
// This can be optimized to direct pointer usage when safe
680-
// Process recursively through function calls
672+
// Start with all global functions
681673
for (auto inst : module->getGlobalInsts())
682674
{
683675
auto func = as<IRFunc>(inst);
684676
if (!func)
685677
continue;
678+
if (workListSet.add(func))
679+
workList.add(func);
680+
}
681+
682+
// Recursively add called functions
683+
for (Index i = 0; i < workList.getCount(); i++)
684+
{
685+
auto func = workList[i];
686+
for (auto block : func->getBlocks())
687+
{
688+
for (auto inst : block->getChildren())
689+
{
690+
if (auto call = as<IRCall>(inst))
691+
{
692+
if (auto calledFunc = as<IRFunc>(call->getCallee()))
693+
{
694+
if (workListSet.add(calledFunc))
695+
workList.add(calledFunc);
696+
}
697+
}
698+
}
699+
}
700+
}
686701

687-
changed |= eliminateLoadStorePairsInFunc(func, processedFuncs);
702+
for (auto func : workList)
703+
{
704+
overallChanged |= eliminateRedundantTemporaryCopyInFunc(func);
688705
}
689706

690-
return changed;
707+
return overallChanged;
691708
}
692709

693710
bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options)

source/slang/slang-ir-dce.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,13 @@ bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex);
3636

3737
bool trimOptimizableTypes(IRModule* module);
3838

39-
/// Eliminate unnecessary load+store pairs when safe to do so.
40-
/// This optimization looks for patterns where a value is loaded from a ConstRef
41-
/// parameter and immediately stored to a temporary variable, then only passed
42-
/// to functions that accept ConstRef parameters. In such cases, the temporary
43-
/// variable can be eliminated and the original ConstRef parameter used directly.
39+
/// Eliminate redundant temporary variable copies in load-store patterns.
40+
/// This optimization looks for patterns where a value is loaded from memory
41+
/// and immediately stored to a temporary variable, which is then only used
42+
/// in read-only contexts. In such cases, the temporary variable and the
43+
/// load-store indirection can be eliminated by using the original memory
44+
/// location directly.
4445
/// Returns true if any changes were made.
45-
bool eliminateLoadStorePairs(IRModule* module);
46+
bool eliminateRedundantTemporaryCopy(IRModule* module);
4647

4748
} // namespace Slang

source/slang/slang-ir-ssa-simplification.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ void simplifyIR(
7474
changed |= applySparseConditionalConstantPropagationForGlobalScope(module, sink);
7575
changed |= peepholeOptimizeGlobalScope(target, module);
7676
changed |= trimOptimizableTypes(module);
77-
changed |= eliminateLoadStorePairs(module);
77+
changed |= eliminateRedundantTemporaryCopy(module);
7878

7979
for (auto inst : module->getGlobalInsts())
8080
{

tests/metal/vector-get-element-ptr.slang

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ void modify(inout int v)
1414
void computeMain(int3 v : SV_DispatchThreadID)
1515
{
1616
int3 u = v;
17-
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = u{{.*}}.x;
17+
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = [[OUT:[a-zA-Z0-9_]+]].x;
1818
// CHECK: modify{{.*}}(&[[TEMP]])
19-
// CHECK: u{{.*}}.x = [[TEMP]];
19+
// CHECK: [[OUT]].x = [[TEMP]];
2020

2121
modify(u.x);
2222
// BUF: 2
2323
outputBuffer[0] = u.x + u.y;
24-
}
24+
}

0 commit comments

Comments
 (0)