@@ -512,182 +512,199 @@ bool trimOptimizableTypes(IRModule* module)
     return changed;
 }

-static bool eliminateLoadStorePairsInFunc(IRFunc* func, HashSet<IRFunc*>& processedFuncs)
+static bool eliminateRedundantTemporaryCopyInFunc(IRFunc* func)
 {
-    // Avoid infinite recursion by tracking processed functions
-    if (processedFuncs.contains(func))
-        return false;
-
-    processedFuncs.add(func);
-
-    bool changed = false;
-    HashSet<IRInst*> toRemove;
-
-    for (auto block : func->getBlocks())
+    // Consider the following IR pattern:
+    // ```
+    // let %temp = var
+    // let %value = load(%sourcePtr)
+    // store(%temp, %value)
+    // ```
+    // We can replace "%temp" with "%sourcePtr" without the load and store indirection
+    // if "%temp" is used only in read-only contexts.
+
+    bool overallChanged = false;
+    for (bool changed = true; changed;)
     {
-        for (auto blockInst : block->getChildren())
-        {
-            // First, recursively process any function calls in this block
-            if (auto call = as<IRCall>(blockInst))
-            {
-                auto callee = call->getCallee();
-                if (auto calleeFunc = as<IRFunc>(callee))
-                {
-                    changed |= eliminateLoadStorePairsInFunc(calleeFunc, processedFuncs);
-                }
-            }
-
-            auto storeInst = as<IRStore>(blockInst);
-            if (!storeInst)
-                continue;
+        changed = false;

-            auto storedValue = storeInst->getVal();
-            auto destPtr = storeInst->getPtr();
-
-            // Only optimize if destPtr is a variable (kIROp_Var).
-            // Don't optimize stores to buffer elements or other meaningful destinations.
-            if (!as<IRVar>(destPtr))
-                continue;
-
-            // Check if we're storing a load result
-            auto loadInst = as<IRLoad>(storedValue);
-            if (!loadInst)
-                continue;
+        HashSet<IRInst*> toRemove;

-            // Do not optimize the primitive types because the legalization step may assume the
-            // existance of Load instructions when the entry parameters are replaced with the
-            // builtin variables.
-            auto loadInstType = loadInst->getDataType();
-            switch (loadInstType->getOp())
+        for (auto block : func->getBlocks())
+        {
+            for (auto blockInst : block->getChildren())
             {
-            case kIROp_Int8Type:
-            case kIROp_Int16Type:
-            case kIROp_IntType:
-            case kIROp_Int64Type:
-            case kIROp_UInt8Type:
-            case kIROp_UInt16Type:
-            case kIROp_UIntType:
-            case kIROp_UInt64Type:
-            case kIROp_IntPtrType:
-            case kIROp_UIntPtrType:
-            case kIROp_FloatType:
-            case kIROp_DoubleType:
-            case kIROp_HalfType:
-            case kIROp_BoolType:
-                continue;
-            }
-
-            auto loadedPtr = loadInst->getPtr();
+                auto storeInst = as<IRStore>(blockInst);
+                if (!storeInst)
+                    continue;

-            // Optimize only when the type is immutable (ConstRef).
-            auto loadedPtrType = loadedPtr->getDataType();
-            if (!as<IRConstRefType>(loadedPtrType) && !as<IRParameterBlockType>(loadedPtrType) &&
-                !as<IRConstantBufferType>(loadedPtrType))
-                continue;
+                auto storedValue = storeInst->getVal();
+                auto destPtr = storeInst->getPtr();

-            // We cannot optimize if the variable is passed to a function call that treats the
-            // parameter as mutable.
-            bool isSafeToOptimize = true;
+                // Only optimize temporary variables.
+                // Don't optimize stores to permanent memory locations.
+                if (destPtr->getOp() != kIROp_Var)
+                    continue;

-            // Check if variable is only used in function calls with ConstRef parameters
-            for (auto use = destPtr->firstUse; use; use = use->nextUse)
-            {
-                auto user = use->getUser();
-                if (user == storeInst)
-                    continue; // Skip the store itself
+                // Check if we're storing a load result
+                auto loadInst = as<IRLoad>(storedValue);
+                if (!loadInst)
+                    continue;

-                auto call = as<IRCall>(user);
-                if (!call)
+                auto loadedPtr = loadInst->getPtr();
+
+                // Do not optimize loads from semantic parameters because some semantics have
+                // builtin types that are vector types but pretend to be scalar types (e.g.,
+                // SV_DispatchThreadID is used as 'int id' but maps to 'float3
+                // gl_GlobalInvocationID'). The legalization step must remove the load instruction
+                // to maintain this pretense, which breaks our load/store optimization assumptions.
+                // Skip optimization when loading from semantics to let legalization handle the load
+                // removal.
+                if (auto param = as<IRParam>(loadedPtr))
+                    if (param->findDecoration<IRSemanticDecoration>())
+                        continue;
+
+                // We cannot optimize if the variable is passed to a function call that treats the
+                // parameter as mutable, or if there are multiple stores to the variable.
+
+                // If the pointer is re-used with IRStore more than once,
+                // handling the case will become more complex.
+                // TODO: we may want to handle the complex cases as well later.
+                UInt storeCount = 0;
+                for (auto use = destPtr->firstUse; use; use = use->nextUse)
                 {
-                    isSafeToOptimize = false;
-                    break;
+                    if (as<IRStore>(use->getUser()))
+                    {
+                        storeCount++;
+                        if (storeCount > 1)
+                            goto unsafeToOptimize;
+                    }
                 }
-                auto callee = call->getCallee();
-                auto funcInst = as<IRFunc>(callee);
-                if (!funcInst)
+                // Check all uses of the destination variable
+                for (auto use = destPtr->firstUse; use; use = use->nextUse)
                 {
-                    isSafeToOptimize = false;
-                    break;
-                }
+                    auto user = use->getUser();
+                    if (user == storeInst)
+                        continue; // Skip the store itself

-                // Find which argument position this variable is used in
-                UInt argIndex = 0;
-                bool foundArg = false;
-                for (UInt i = 0; i < call->getArgCount(); i++)
-                {
-                    if (call->getArg(i) == destPtr)
+                    // Allow loads - they are safe since we're not changing the data
+                    if (as<IRLoad>(user))
+                        continue; // Check the next use
+
+                    // For function calls, check if the parameter is ConstRef
+                    if (auto call = as<IRCall>(user))
                     {
-                        argIndex = i;
-                        foundArg = true;
-                        break;
+                        auto callee = call->getCallee();
+                        auto funcInst = as<IRFunc>(callee);
+                        if (!funcInst)
+                            goto unsafeToOptimize;
+
+                        // Build parameter list once for efficient indexing
+                        List<IRParam*> params;
+                        for (auto param = funcInst->getFirstParam(); param;
+                             param = param->getNextParam())
+                        {
+                            params.add(param);
+                        }
+
+                        // Check ALL argument positions where this variable is used
+                        for (UInt i = 0; i < call->getArgCount(); i++)
+                        {
+                            if (call->getArg(i) == destPtr)
+                            {
+                                if (i >= (UInt)params.getCount())
+                                    goto unsafeToOptimize;
+
+                                // Check if this parameter position is ConstRef
+                                auto param = params[i];
+                                if (!as<IRConstRefType>(param->getDataType()))
+                                    goto unsafeToOptimize;
+                            }
+                        }
+                        continue; // Safe so far and check the next use
                     }
-                }

-                if (!foundArg)
-                {
-                    isSafeToOptimize = false;
-                    break;
-                }
+                    // TODO: there might be more cases that are safe to optimize
+                    // We need to add more cases here as needed.

-                // Check if the corresponding parameter is ConstRef
-                auto param = funcInst->getFirstParam();
-                for (UInt i = 0; i < argIndex && param; i++)
-                {
-                    param = param->getNextParam();
-                }
-                if (!param || !as<IRConstRefType>(param->getDataType()))
-                {
-                    // Unsafe, because the parameter might be used as mutable.
-                    isSafeToOptimize = false;
-                    break;
+                    goto unsafeToOptimize;
                 }
-            }

-            if (!isSafeToOptimize)
-                continue;
+                // If we get here, all uses are safe to optimize
+                // safeToOptimize:

-            // Replace all uses of destPtr with loadedPtr
-            destPtr->replaceUsesWith(loadedPtr);
+                // Replace all uses of destPtr with loadedPtr
+                destPtr->replaceUsesWith(loadedPtr);

-            // Mark both instructions for removal
-            toRemove.add(storeInst);
-            toRemove.add(destPtr);
+                // Mark instructions for removal
+                toRemove.add(storeInst);
+                toRemove.add(destPtr);

-            // Note: loadInst might be still in use.
-            // We need to rely on DCE to delete it if unused.
+                // Note: loadInst might still be in use.
+                // We need to rely on DCE to delete it if unused.

-            changed = true;
+                changed = true;
+                overallChanged = true;
+
+            unsafeToOptimize:;
+            }
         }
-    }

-    // Remove marked instructions
-    for (auto instToRemove : toRemove)
-    {
-        instToRemove->removeAndDeallocate();
+        // Remove marked instructions
+        for (auto instToRemove : toRemove)
+        {
+            instToRemove->removeAndDeallocate();
+        }
     }

-    return changed;
+    return overallChanged;
 }

-bool eliminateLoadStorePairs(IRModule* module)
+bool eliminateRedundantTemporaryCopy(IRModule* module)
 {
-    bool changed = false;
-    HashSet<IRFunc*> processedFuncs;
+    bool overallChanged = false;

-    // Look for patterns: load(ptr) followed by store(var, load_result)
-    // This can be optimized to direct pointer usage when safe
-    // Process recursively through function calls
+    // Populate work list with all functions and any functions they call
+    HashSet<IRFunc*> workListSet;
+    List<IRFunc*> workList;
+
+    // Start with all global functions
     for (auto inst : module->getGlobalInsts())
     {
         auto func = as<IRFunc>(inst);
         if (!func)
             continue;
+        if (workListSet.add(func))
+            workList.add(func);
+    }

-        changed |= eliminateLoadStorePairsInFunc(func, processedFuncs);
+    // Recursively add called functions
+    for (Index i = 0; i < workList.getCount(); i++)
+    {
+        auto func = workList[i];
+        for (auto block : func->getBlocks())
+        {
+            for (auto inst : block->getChildren())
+            {
+                if (auto call = as<IRCall>(inst))
+                {
+                    if (auto calledFunc = as<IRFunc>(call->getCallee()))
+                    {
+                        if (workListSet.add(calledFunc))
+                            workList.add(calledFunc);
+                    }
+                }
+            }
+        }
     }

-    return changed;
+    for (auto func : workList)
+    {
+        overallChanged |= eliminateRedundantTemporaryCopyInFunc(func);
+    }
+
+    return overallChanged;
 }

 bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options)
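For intuition, here is a minimal source-level analogy of the copy this pass removes. It is plain C++ rather than Slang IR, and the names `readOnlyUse`, `callerBefore`, and `callerAfter` are hypothetical, chosen only for illustration: a temporary is initialized from a source location and then reaches only read-only uses, so the temporary (and its load/store pair) can be dropped and the source used directly.

```cpp
#include <cstdio>

// Hypothetical callee that only reads its argument; in IR terms this corresponds
// to a parameter of ConstRef type, so a caller-side copy carries no meaning.
static int readOnlyUse(const int& value)
{
    return value + 1;
}

// Before: conceptually `%temp = var; %value = load(%sourcePtr); store(%temp, %value)`.
static int callerBefore(const int& source)
{
    int temp = source;        // the redundant temporary copy
    return readOnlyUse(temp); // %temp reaches only read-only uses
}

// After: every use of %temp is replaced with %sourcePtr and the copy disappears.
static int callerAfter(const int& source)
{
    return readOnlyUse(source);
}

int main()
{
    int x = 41;
    // Both forms observe the same value because the callee never writes through it.
    std::printf("%d %d\n", callerBefore(x), callerAfter(x));
    return 0;
}
```

If `readOnlyUse` instead took a mutable `int&` and wrote through it, the rewrite would redirect that write to `source`, which is why the pass in the diff bails out unless the variable has a single store and every call site binds it to a ConstRef parameter.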