Skip to content

Commit 0fe6fd6

Browse files
committed
Refactor based on the feedback
1 parent 244e17e commit 0fe6fd6

File tree

4 files changed

+163
-146
lines changed

4 files changed

+163
-146
lines changed

source/slang/slang-ir-dce.cpp

Lines changed: 152 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -512,182 +512,198 @@ bool trimOptimizableTypes(IRModule* module)
512512
return changed;
513513
}
514514

515-
static bool eliminateLoadStorePairsInFunc(IRFunc* func, HashSet<IRFunc*>& processedFuncs)
515+
static bool eliminateRedundantTemporaryCopyInFunc(IRFunc* func)
516516
{
517-
// Avoid infinite recursion by tracking processed functions
518-
if (processedFuncs.contains(func))
519-
return false;
520-
521-
processedFuncs.add(func);
522-
523-
bool changed = false;
524-
HashSet<IRInst*> toRemove;
525-
526-
for (auto block : func->getBlocks())
517+
// Consider the following IR pattern:
518+
// ```
519+
// let %temp = var
520+
// let %value = load(%sourcePtr)
521+
// store(%temp, %value)
522+
// ```
523+
// We can replace "%temp" with "%sourcePtr" without the load and store indirection
524+
// if "%temp" is used only in read-only contexts.
525+
526+
bool overallChanged = false;
527+
for (bool changed = true; changed; )
527528
{
528-
for (auto blockInst : block->getChildren())
529-
{
530-
// First, recursively process any function calls in this block
531-
if (auto call = as<IRCall>(blockInst))
532-
{
533-
auto callee = call->getCallee();
534-
if (auto calleeFunc = as<IRFunc>(callee))
535-
{
536-
changed |= eliminateLoadStorePairsInFunc(calleeFunc, processedFuncs);
537-
}
538-
}
539-
540-
auto storeInst = as<IRStore>(blockInst);
541-
if (!storeInst)
542-
continue;
529+
changed = false;
543530

544-
auto storedValue = storeInst->getVal();
545-
auto destPtr = storeInst->getPtr();
546-
547-
// Only optimize if destPtr is a variable (kIROp_Var).
548-
// Don't optimize stores to buffer elements or other meaningful destinations.
549-
if (!as<IRVar>(destPtr))
550-
continue;
551-
552-
// Check if we're storing a load result
553-
auto loadInst = as<IRLoad>(storedValue);
554-
if (!loadInst)
555-
continue;
531+
HashSet<IRInst*> toRemove;
556532

557-
// Do not optimize the primitive types because the legalization step may assume the
558-
// existance of Load instructions when the entry parameters are replaced with the
559-
// builtin variables.
560-
auto loadInstType = loadInst->getDataType();
561-
switch (loadInstType->getOp())
533+
for (auto block : func->getBlocks())
534+
{
535+
for (auto blockInst : block->getChildren())
562536
{
563-
case kIROp_Int8Type:
564-
case kIROp_Int16Type:
565-
case kIROp_IntType:
566-
case kIROp_Int64Type:
567-
case kIROp_UInt8Type:
568-
case kIROp_UInt16Type:
569-
case kIROp_UIntType:
570-
case kIROp_UInt64Type:
571-
case kIROp_IntPtrType:
572-
case kIROp_UIntPtrType:
573-
case kIROp_FloatType:
574-
case kIROp_DoubleType:
575-
case kIROp_HalfType:
576-
case kIROp_BoolType:
577-
continue;
578-
}
579-
580-
auto loadedPtr = loadInst->getPtr();
537+
auto storeInst = as<IRStore>(blockInst);
538+
if (!storeInst)
539+
continue;
581540

582-
// Optimize only when the type is immutable (ConstRef).
583-
auto loadedPtrType = loadedPtr->getDataType();
584-
if (!as<IRConstRefType>(loadedPtrType) && !as<IRParameterBlockType>(loadedPtrType) &&
585-
!as<IRConstantBufferType>(loadedPtrType))
586-
continue;
541+
auto storedValue = storeInst->getVal();
542+
auto destPtr = storeInst->getPtr();
587543

588-
// We cannot optimize if the variable is passed to a function call that treats the
589-
// parameter as mutable.
590-
bool isSafeToOptimize = true;
544+
// Only optimize temporary variable.
545+
// Don't optimize stores to permanent memory locations.
546+
if (destPtr->getOp() != kIROp_Var)
547+
continue;
591548

592-
// Check if variable is only used in function calls with ConstRef parameters
593-
for (auto use = destPtr->firstUse; use; use = use->nextUse)
594-
{
595-
auto user = use->getUser();
596-
if (user == storeInst)
597-
continue; // Skip the store itself
549+
// Check if we're storing a load result
550+
auto loadInst = as<IRLoad>(storedValue);
551+
if (!loadInst)
552+
continue;
598553

599-
auto call = as<IRCall>(user);
600-
if (!call)
554+
auto loadedPtr = loadInst->getPtr();
555+
556+
// Do not optimize loads from semantic parameters because some semantics have
557+
// builtin types that are vector types but pretend to be scalar types (e.g.,
558+
// SV_DispatchThreadID is used as 'int id' but maps to 'float3 gl_GlobalInvocationID').
559+
// The legalization step must remove the load instruction to maintain this pretense,
560+
// which breaks our load/store optimization assumptions. Skip optimization when
561+
// loading from semantics to let legalization handle the load removal.
562+
if (auto param = as<IRParam>(loadedPtr))
563+
if (param->findDecoration<IRSemanticDecoration>())
564+
continue;
565+
566+
// We cannot optimize if the variable is passed to a function call that treats the
567+
// parameter as mutable, or if there are multiple stores to the variable.
568+
569+
// If the pointer is re-used with IRStore more than once,
570+
// handling the case will become more complex.
571+
// TODO: we may want to handle the complex cases as well later.
572+
UInt storeCount = 0;
573+
for (auto use = destPtr->firstUse; use; use = use->nextUse)
601574
{
602-
isSafeToOptimize = false;
603-
break;
575+
if (as<IRStore>(use->getUser()))
576+
{
577+
storeCount++;
578+
if (storeCount > 1)
579+
goto unsafeToOptimize;
580+
}
604581
}
605582

606-
auto callee = call->getCallee();
607-
auto funcInst = as<IRFunc>(callee);
608-
if (!funcInst)
583+
// Check all uses of the destination variable
584+
for (auto use = destPtr->firstUse; use; use = use->nextUse)
609585
{
610-
isSafeToOptimize = false;
611-
break;
612-
}
586+
auto user = use->getUser();
587+
if (user == storeInst)
588+
continue; // Skip the store itself
613589

614-
// Find which argument position this variable is used in
615-
UInt argIndex = 0;
616-
bool foundArg = false;
617-
for (UInt i = 0; i < call->getArgCount(); i++)
618-
{
619-
if (call->getArg(i) == destPtr)
590+
// Allow loads - they are safe since we're not changing the data
591+
if (as<IRLoad>(user))
592+
continue; // Check the next use
593+
594+
// For function calls, check if the parameter is ConstRef
595+
if (auto call = as<IRCall>(user))
620596
{
621-
argIndex = i;
622-
foundArg = true;
623-
break;
597+
auto callee = call->getCallee();
598+
auto funcInst = as<IRFunc>(callee);
599+
if (!funcInst)
600+
goto unsafeToOptimize;
601+
602+
// Build parameter list once for efficient indexing
603+
List<IRParam*> params;
604+
for (auto param = funcInst->getFirstParam(); param;
605+
param = param->getNextParam())
606+
{
607+
params.add(param);
608+
}
609+
610+
// Check ALL argument positions where this variable is used
611+
for (UInt i = 0; i < call->getArgCount(); i++)
612+
{
613+
if (call->getArg(i) == destPtr)
614+
{
615+
if (i >= (UInt)params.getCount())
616+
goto unsafeToOptimize;
617+
618+
// Check if this parameter position is ConstRef
619+
auto param = params[i];
620+
if (!as<IRConstRefType>(param->getDataType()))
621+
goto unsafeToOptimize;
622+
}
623+
}
624+
continue; // Safe so far and check the next use
624625
}
625-
}
626626

627-
if (!foundArg)
628-
{
629-
isSafeToOptimize = false;
630-
break;
631-
}
627+
// TODO: there might be more cases that is safe to optimize
628+
// We need to add more cases here as needed.
632629

633-
// Check if the corresponding parameter is ConstRef
634-
auto param = funcInst->getFirstParam();
635-
for (UInt i = 0; i < argIndex && param; i++)
636-
{
637-
param = param->getNextParam();
638-
}
639-
if (!param || !as<IRConstRefType>(param->getDataType()))
640-
{
641-
// Unsafe, because the parameter might be used as mutable.
642-
isSafeToOptimize = false;
643-
break;
630+
goto unsafeToOptimize;
644631
}
645-
}
646632

647-
if (!isSafeToOptimize)
648-
continue;
633+
// If we get here, all uses are safe to optimize
634+
// safeToOptimize:
649635

650-
// Replace all uses of destPtr with loadedPtr
651-
destPtr->replaceUsesWith(loadedPtr);
636+
// Replace all uses of destPtr with loadedPtr
637+
destPtr->replaceUsesWith(loadedPtr);
652638

653-
// Mark both instructions for removal
654-
toRemove.add(storeInst);
655-
toRemove.add(destPtr);
639+
// Mark instructions for removal
640+
toRemove.add(storeInst);
641+
toRemove.add(destPtr);
656642

657-
// Note: loadInst might be still in use.
658-
// We need to rely on DCE to delete it if unused.
643+
// Note: loadInst might be still in use.
644+
// We need to rely on DCE to delete it if unused.
659645

660-
changed = true;
646+
changed = true;
647+
overallChanged = true;
648+
649+
unsafeToOptimize:;
650+
}
661651
}
662-
}
663652

664-
// Remove marked instructions
665-
for (auto instToRemove : toRemove)
666-
{
667-
instToRemove->removeAndDeallocate();
653+
// Remove marked instructions
654+
for (auto instToRemove : toRemove)
655+
{
656+
instToRemove->removeAndDeallocate();
657+
}
668658
}
669659

670-
return changed;
660+
return overallChanged;
671661
}
672662

673-
bool eliminateLoadStorePairs(IRModule* module)
663+
bool eliminateRedundantTemporaryCopy(IRModule* module)
674664
{
675-
bool changed = false;
676-
HashSet<IRFunc*> processedFuncs;
665+
bool overallChanged = false;
677666

678-
// Look for patterns: load(ptr) followed by store(var, load_result)
679-
// This can be optimized to direct pointer usage when safe
680-
// Process recursively through function calls
667+
// Populate work list with all functions and any functions they call
668+
HashSet<IRFunc*> workListSet;
669+
List<IRFunc*> workList;
670+
671+
// Start with all global functions
681672
for (auto inst : module->getGlobalInsts())
682673
{
683674
auto func = as<IRFunc>(inst);
684675
if (!func)
685676
continue;
677+
if (workListSet.add(func))
678+
workList.add(func);
679+
}
686680

687-
changed |= eliminateLoadStorePairsInFunc(func, processedFuncs);
681+
// Recursively add called functions
682+
for (Index i = 0; i < workList.getCount(); i++)
683+
{
684+
auto func = workList[i];
685+
for (auto block : func->getBlocks())
686+
{
687+
for (auto inst : block->getChildren())
688+
{
689+
if (auto call = as<IRCall>(inst))
690+
{
691+
if (auto calledFunc = as<IRFunc>(call->getCallee()))
692+
{
693+
if (workListSet.add(calledFunc))
694+
workList.add(calledFunc);
695+
}
696+
}
697+
}
698+
}
688699
}
689700

690-
return changed;
701+
for (auto func : workList)
702+
{
703+
overallChanged |= eliminateRedundantTemporaryCopyInFunc(func);
704+
}
705+
706+
return overallChanged;
691707
}
692708

693709
bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options)

source/slang/slang-ir-dce.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,13 @@ bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex);
3636

3737
bool trimOptimizableTypes(IRModule* module);
3838

39-
/// Eliminate unnecessary load+store pairs when safe to do so.
40-
/// This optimization looks for patterns where a value is loaded from a ConstRef
41-
/// parameter and immediately stored to a temporary variable, then only passed
42-
/// to functions that accept ConstRef parameters. In such cases, the temporary
43-
/// variable can be eliminated and the original ConstRef parameter used directly.
39+
/// Eliminate redundant temporary variable copies in load-store patterns.
40+
/// This optimization looks for patterns where a value is loaded from memory
41+
/// and immediately stored to a temporary variable, which is then only used
42+
/// in read-only contexts. In such cases, the temporary variable and the
43+
/// load-store indirection can be eliminated by using the original memory
44+
/// location directly.
4445
/// Returns true if any changes were made.
45-
bool eliminateLoadStorePairs(IRModule* module);
46+
bool eliminateRedundantTemporaryCopy(IRModule* module);
4647

4748
} // namespace Slang

source/slang/slang-ir-ssa-simplification.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ void simplifyIR(
7474
changed |= applySparseConditionalConstantPropagationForGlobalScope(module, sink);
7575
changed |= peepholeOptimizeGlobalScope(target, module);
7676
changed |= trimOptimizableTypes(module);
77-
changed |= eliminateLoadStorePairs(module);
77+
changed |= eliminateRedundantTemporaryCopy(module);
7878

7979
for (auto inst : module->getGlobalInsts())
8080
{

tests/metal/vector-get-element-ptr.slang

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ void modify(inout int v)
1414
void computeMain(int3 v : SV_DispatchThreadID)
1515
{
1616
int3 u = v;
17-
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = u{{.*}}.x;
17+
// CHECK: int [[TEMP:[a-zA-Z0-9_]+]] = [[OUT:[a-zA-Z0-9_]+]].x;
1818
// CHECK: modify{{.*}}(&[[TEMP]])
19-
// CHECK: u{{.*}}.x = [[TEMP]];
19+
// CHECK: [[OUT]].x = [[TEMP]];
2020

2121
modify(u.x);
2222
// BUF: 2
2323
outputBuffer[0] = u.x + u.y;
24-
}
24+
}

0 commit comments

Comments
 (0)