@@ -512,182 +512,198 @@ bool trimOptimizableTypes(IRModule* module)
512
512
return changed;
513
513
}
514
514
515
- static bool eliminateLoadStorePairsInFunc (IRFunc* func, HashSet<IRFunc*>& processedFuncs )
515
+ static bool eliminateRedundantTemporaryCopyInFunc (IRFunc* func)
516
516
{
517
- // Avoid infinite recursion by tracking processed functions
518
- if (processedFuncs.contains (func))
519
- return false ;
520
-
521
- processedFuncs.add (func);
522
-
523
- bool changed = false ;
524
- HashSet<IRInst*> toRemove;
525
-
526
- for (auto block : func->getBlocks ())
517
+ // Consider the following IR pattern:
518
+ // ```
519
+ // let %temp = var
520
+ // let %value = load(%sourcePtr)
521
+ // store(%temp, %value)
522
+ // ```
523
+ // We can replace "%temp" with "%sourcePtr" without the load and store indirection
524
+ // if "%temp" is used only in read-only contexts.
525
+
526
+ bool overallChanged = false ;
527
+ for (bool changed = true ; changed; )
527
528
{
528
- for (auto blockInst : block->getChildren ())
529
- {
530
- // First, recursively process any function calls in this block
531
- if (auto call = as<IRCall>(blockInst))
532
- {
533
- auto callee = call->getCallee ();
534
- if (auto calleeFunc = as<IRFunc>(callee))
535
- {
536
- changed |= eliminateLoadStorePairsInFunc (calleeFunc, processedFuncs);
537
- }
538
- }
539
-
540
- auto storeInst = as<IRStore>(blockInst);
541
- if (!storeInst)
542
- continue ;
529
+ changed = false ;
543
530
544
- auto storedValue = storeInst->getVal ();
545
- auto destPtr = storeInst->getPtr ();
546
-
547
- // Only optimize if destPtr is a variable (kIROp_Var).
548
- // Don't optimize stores to buffer elements or other meaningful destinations.
549
- if (!as<IRVar>(destPtr))
550
- continue ;
551
-
552
- // Check if we're storing a load result
553
- auto loadInst = as<IRLoad>(storedValue);
554
- if (!loadInst)
555
- continue ;
531
+ HashSet<IRInst*> toRemove;
556
532
557
- // Do not optimize the primitive types because the legalization step may assume the
558
- // existance of Load instructions when the entry parameters are replaced with the
559
- // builtin variables.
560
- auto loadInstType = loadInst->getDataType ();
561
- switch (loadInstType->getOp ())
533
+ for (auto block : func->getBlocks ())
534
+ {
535
+ for (auto blockInst : block->getChildren ())
562
536
{
563
- case kIROp_Int8Type :
564
- case kIROp_Int16Type :
565
- case kIROp_IntType :
566
- case kIROp_Int64Type :
567
- case kIROp_UInt8Type :
568
- case kIROp_UInt16Type :
569
- case kIROp_UIntType :
570
- case kIROp_UInt64Type :
571
- case kIROp_IntPtrType :
572
- case kIROp_UIntPtrType :
573
- case kIROp_FloatType :
574
- case kIROp_DoubleType :
575
- case kIROp_HalfType :
576
- case kIROp_BoolType :
577
- continue ;
578
- }
579
-
580
- auto loadedPtr = loadInst->getPtr ();
537
+ auto storeInst = as<IRStore>(blockInst);
538
+ if (!storeInst)
539
+ continue ;
581
540
582
- // Optimize only when the type is immutable (ConstRef).
583
- auto loadedPtrType = loadedPtr->getDataType ();
584
- if (!as<IRConstRefType>(loadedPtrType) && !as<IRParameterBlockType>(loadedPtrType) &&
585
- !as<IRConstantBufferType>(loadedPtrType))
586
- continue ;
541
+ auto storedValue = storeInst->getVal ();
542
+ auto destPtr = storeInst->getPtr ();
587
543
588
- // We cannot optimize if the variable is passed to a function call that treats the
589
- // parameter as mutable.
590
- bool isSafeToOptimize = true ;
544
+ // Only optimize temporary variable.
545
+ // Don't optimize stores to permanent memory locations.
546
+ if (destPtr->getOp () != kIROp_Var )
547
+ continue ;
591
548
592
- // Check if variable is only used in function calls with ConstRef parameters
593
- for (auto use = destPtr->firstUse ; use; use = use->nextUse )
594
- {
595
- auto user = use->getUser ();
596
- if (user == storeInst)
597
- continue ; // Skip the store itself
549
+ // Check if we're storing a load result
550
+ auto loadInst = as<IRLoad>(storedValue);
551
+ if (!loadInst)
552
+ continue ;
598
553
599
- auto call = as<IRCall>(user);
600
- if (!call)
554
+ auto loadedPtr = loadInst->getPtr ();
555
+
556
+ // Do not optimize loads from semantic parameters because some semantics have
557
+ // builtin types that are vector types but pretend to be scalar types (e.g.,
558
+ // SV_DispatchThreadID is used as 'int id' but maps to 'float3 gl_GlobalInvocationID').
559
+ // The legalization step must remove the load instruction to maintain this pretense,
560
+ // which breaks our load/store optimization assumptions. Skip optimization when
561
+ // loading from semantics to let legalization handle the load removal.
562
+ if (auto param = as<IRParam>(loadedPtr))
563
+ if (param->findDecoration <IRSemanticDecoration>())
564
+ continue ;
565
+
566
+ // We cannot optimize if the variable is passed to a function call that treats the
567
+ // parameter as mutable, or if there are multiple stores to the variable.
568
+
569
+ // If the pointer is re-used with IRStore more than once,
570
+ // handling the case will become more complex.
571
+ // TODO: we may want to handle the complex cases as well later.
572
+ UInt storeCount = 0 ;
573
+ for (auto use = destPtr->firstUse ; use; use = use->nextUse )
601
574
{
602
- isSafeToOptimize = false ;
603
- break ;
575
+ if (as<IRStore>(use->getUser ()))
576
+ {
577
+ storeCount++;
578
+ if (storeCount > 1 )
579
+ goto unsafeToOptimize;
580
+ }
604
581
}
605
582
606
- auto callee = call->getCallee ();
607
- auto funcInst = as<IRFunc>(callee);
608
- if (!funcInst)
583
+ // Check all uses of the destination variable
584
+ for (auto use = destPtr->firstUse ; use; use = use->nextUse )
609
585
{
610
- isSafeToOptimize = false ;
611
- break ;
612
- }
586
+ auto user = use-> getUser () ;
587
+ if (user == storeInst)
588
+ continue ; // Skip the store itself
613
589
614
- // Find which argument position this variable is used in
615
- UInt argIndex = 0 ;
616
- bool foundArg = false ;
617
- for (UInt i = 0 ; i < call-> getArgCount (); i++)
618
- {
619
- if (call-> getArg (i) == destPtr )
590
+ // Allow loads - they are safe since we're not changing the data
591
+ if (as<IRLoad>(user))
592
+ continue ; // Check the next use
593
+
594
+ // For function calls, check if the parameter is ConstRef
595
+ if (auto call = as<IRCall>(user) )
620
596
{
621
- argIndex = i;
622
- foundArg = true ;
623
- break ;
597
+ auto callee = call->getCallee ();
598
+ auto funcInst = as<IRFunc>(callee);
599
+ if (!funcInst)
600
+ goto unsafeToOptimize;
601
+
602
+ // Build parameter list once for efficient indexing
603
+ List<IRParam*> params;
604
+ for (auto param = funcInst->getFirstParam (); param;
605
+ param = param->getNextParam ())
606
+ {
607
+ params.add (param);
608
+ }
609
+
610
+ // Check ALL argument positions where this variable is used
611
+ for (UInt i = 0 ; i < call->getArgCount (); i++)
612
+ {
613
+ if (call->getArg (i) == destPtr)
614
+ {
615
+ if (i >= (UInt)params.getCount ())
616
+ goto unsafeToOptimize;
617
+
618
+ // Check if this parameter position is ConstRef
619
+ auto param = params[i];
620
+ if (!as<IRConstRefType>(param->getDataType ()))
621
+ goto unsafeToOptimize;
622
+ }
623
+ }
624
+ continue ; // Safe so far and check the next use
624
625
}
625
- }
626
626
627
- if (!foundArg)
628
- {
629
- isSafeToOptimize = false ;
630
- break ;
631
- }
627
+ // TODO: there might be more cases that is safe to optimize
628
+ // We need to add more cases here as needed.
632
629
633
- // Check if the corresponding parameter is ConstRef
634
- auto param = funcInst->getFirstParam ();
635
- for (UInt i = 0 ; i < argIndex && param; i++)
636
- {
637
- param = param->getNextParam ();
638
- }
639
- if (!param || !as<IRConstRefType>(param->getDataType ()))
640
- {
641
- // Unsafe, because the parameter might be used as mutable.
642
- isSafeToOptimize = false ;
643
- break ;
630
+ goto unsafeToOptimize;
644
631
}
645
- }
646
632
647
- if (!isSafeToOptimize)
648
- continue ;
633
+ // If we get here, all uses are safe to optimize
634
+ // safeToOptimize:
649
635
650
- // Replace all uses of destPtr with loadedPtr
651
- destPtr->replaceUsesWith (loadedPtr);
636
+ // Replace all uses of destPtr with loadedPtr
637
+ destPtr->replaceUsesWith (loadedPtr);
652
638
653
- // Mark both instructions for removal
654
- toRemove.add (storeInst);
655
- toRemove.add (destPtr);
639
+ // Mark instructions for removal
640
+ toRemove.add (storeInst);
641
+ toRemove.add (destPtr);
656
642
657
- // Note: loadInst might be still in use.
658
- // We need to rely on DCE to delete it if unused.
643
+ // Note: loadInst might be still in use.
644
+ // We need to rely on DCE to delete it if unused.
659
645
660
- changed = true ;
646
+ changed = true ;
647
+ overallChanged = true ;
648
+
649
+ unsafeToOptimize:;
650
+ }
661
651
}
662
- }
663
652
664
- // Remove marked instructions
665
- for (auto instToRemove : toRemove)
666
- {
667
- instToRemove->removeAndDeallocate ();
653
+ // Remove marked instructions
654
+ for (auto instToRemove : toRemove)
655
+ {
656
+ instToRemove->removeAndDeallocate ();
657
+ }
668
658
}
669
659
670
- return changed ;
660
+ return overallChanged ;
671
661
}
672
662
673
- bool eliminateLoadStorePairs (IRModule* module )
663
+ bool eliminateRedundantTemporaryCopy (IRModule* module )
674
664
{
675
- bool changed = false ;
676
- HashSet<IRFunc*> processedFuncs;
665
+ bool overallChanged = false ;
677
666
678
- // Look for patterns: load(ptr) followed by store(var, load_result)
679
- // This can be optimized to direct pointer usage when safe
680
- // Process recursively through function calls
667
+ // Populate work list with all functions and any functions they call
668
+ HashSet<IRFunc*> workListSet;
669
+ List<IRFunc*> workList;
670
+
671
+ // Start with all global functions
681
672
for (auto inst : module ->getGlobalInsts ())
682
673
{
683
674
auto func = as<IRFunc>(inst);
684
675
if (!func)
685
676
continue ;
677
+ if (workListSet.add (func))
678
+ workList.add (func);
679
+ }
686
680
687
- changed |= eliminateLoadStorePairsInFunc (func, processedFuncs);
681
+ // Recursively add called functions
682
+ for (Index i = 0 ; i < workList.getCount (); i++)
683
+ {
684
+ auto func = workList[i];
685
+ for (auto block : func->getBlocks ())
686
+ {
687
+ for (auto inst : block->getChildren ())
688
+ {
689
+ if (auto call = as<IRCall>(inst))
690
+ {
691
+ if (auto calledFunc = as<IRFunc>(call->getCallee ()))
692
+ {
693
+ if (workListSet.add (calledFunc))
694
+ workList.add (calledFunc);
695
+ }
696
+ }
697
+ }
698
+ }
688
699
}
689
700
690
- return changed;
701
+ for (auto func : workList)
702
+ {
703
+ overallChanged |= eliminateRedundantTemporaryCopyInFunc (func);
704
+ }
705
+
706
+ return overallChanged;
691
707
}
692
708
693
709
bool shouldInstBeLiveIfParentIsLive (IRInst* inst, IRDeadCodeEliminationOptions options)
0 commit comments