@@ -512,6 +512,180 @@ bool trimOptimizableTypes(IRModule* module)
512
512
return changed;
513
513
}
514
514
515
+ static bool eliminateLoadStorePairsInFunc (IRFunc* func, HashSet<IRFunc*>& processedFuncs)
516
+ {
517
+ // Avoid infinite recursion by tracking processed functions
518
+ if (processedFuncs.contains (func))
519
+ return false ;
520
+
521
+ processedFuncs.add (func);
522
+
523
+ bool changed = false ;
524
+ List<IRInst*> toRemove;
525
+
526
+ for (auto block : func->getBlocks ())
527
+ {
528
+ for (auto blockInst : block->getChildren ())
529
+ {
530
+ // First, recursively process any function calls in this block
531
+ if (auto call = as<IRCall>(blockInst))
532
+ {
533
+ auto callee = call->getCallee ();
534
+ if (auto calleeFunc = as<IRFunc>(callee))
535
+ {
536
+ changed |= eliminateLoadStorePairsInFunc (calleeFunc, processedFuncs);
537
+ }
538
+ }
539
+
540
+ auto storeInst = as<IRStore>(blockInst);
541
+ if (!storeInst)
542
+ continue ;
543
+
544
+ auto storedValue = storeInst->getVal ();
545
+ auto destPtr = storeInst->getPtr ();
546
+
547
+ // Only optimize if destPtr is a variable (kIROp_Var).
548
+ // Don't optimize stores to buffer elements or other meaningful destinations.
549
+ if (!as<IRVar>(destPtr))
550
+ continue ;
551
+
552
+ // Check if we're storing a load result
553
+ auto loadInst = as<IRLoad>(storedValue);
554
+ if (!loadInst)
555
+ continue ;
556
+
557
+ // Do not optimize the primitive types because the legalization step may assume the
558
+ // existance of Load instructions when the entry parameters are replaced with the
559
+ // builtin variables.
560
+ auto loadInstType = loadInst->getDataType ();
561
+ switch (loadInstType->getOp ())
562
+ {
563
+ case kIROp_Int8Type :
564
+ case kIROp_Int16Type :
565
+ case kIROp_IntType :
566
+ case kIROp_Int64Type :
567
+ case kIROp_UInt8Type :
568
+ case kIROp_UInt16Type :
569
+ case kIROp_UIntType :
570
+ case kIROp_UInt64Type :
571
+ case kIROp_IntPtrType :
572
+ case kIROp_UIntPtrType :
573
+ case kIROp_FloatType :
574
+ case kIROp_DoubleType :
575
+ case kIROp_HalfType :
576
+ case kIROp_BoolType :
577
+ continue ;
578
+ }
579
+
580
+ auto loadedPtr = loadInst->getPtr ();
581
+
582
+ // Optimize only when the type is immutable (ConstRef).
583
+ auto loadedPtrType = loadedPtr->getDataType ();
584
+ if (!as<IRConstRefType>(loadedPtrType) && !as<IRParameterBlockType>(loadedPtrType) &&
585
+ !as<IRConstantBufferType>(loadedPtrType))
586
+ continue ;
587
+
588
+ // We cannot optimize if the variable is passed to a function call that treats the
589
+ // parameter as mutable.
590
+ bool isSafeToOptimize = true ;
591
+
592
+ // Check if variable is only used in function calls with ConstRef parameters
593
+ for (auto use = destPtr->firstUse ; use; use = use->nextUse )
594
+ {
595
+ auto user = use->getUser ();
596
+ if (user == storeInst) continue ; // Skip the store itself
597
+
598
+ auto call = as<IRCall>(user);
599
+ if (!call)
600
+ {
601
+ isSafeToOptimize = false ;
602
+ break ;
603
+ }
604
+
605
+ auto callee = call->getCallee ();
606
+ auto funcInst = as<IRFunc>(callee);
607
+ if (!funcInst)
608
+ {
609
+ isSafeToOptimize = false ;
610
+ break ;
611
+ }
612
+
613
+ // Find which argument position this variable is used in
614
+ UInt argIndex = 0 ;
615
+ bool foundArg = false ;
616
+ for (UInt i = 0 ; i < call->getArgCount (); i++)
617
+ {
618
+ if (call->getArg (i) == destPtr)
619
+ {
620
+ argIndex = i;
621
+ foundArg = true ;
622
+ break ;
623
+ }
624
+ }
625
+
626
+ if (!foundArg)
627
+ {
628
+ isSafeToOptimize = false ;
629
+ break ;
630
+ }
631
+
632
+ // Check if the corresponding parameter is ConstRef
633
+ auto param = funcInst->getFirstParam ();
634
+ for (UInt i = 0 ; i < argIndex && param; i++)
635
+ {
636
+ param = param->getNextParam ();
637
+ }
638
+ if (!param || !as<IRConstRefType>(param->getDataType ()))
639
+ {
640
+ // Unsafe, because the parameter might be used as mutable.
641
+ isSafeToOptimize = false ;
642
+ break ;
643
+ }
644
+ }
645
+
646
+ if (!isSafeToOptimize)
647
+ continue ;
648
+
649
+ // Replace all uses of destPtr with loadedPtr
650
+ destPtr->replaceUsesWith (loadedPtr);
651
+
652
+ // Mark both instructions for removal
653
+ toRemove.add (storeInst);
654
+ toRemove.add (loadInst);
655
+ toRemove.add (destPtr);
656
+ changed = true ;
657
+ }
658
+ }
659
+
660
+ // Remove marked instructions
661
+ for (auto instToRemove : toRemove)
662
+ {
663
+ instToRemove->removeAndDeallocate ();
664
+ }
665
+
666
+ return changed;
667
+ }
668
+
669
+ bool eliminateLoadStorePairs (IRModule* module )
670
+ {
671
+ bool changed = false ;
672
+ HashSet<IRFunc*> processedFuncs;
673
+
674
+ // Look for patterns: load(ptr) followed by store(var, load_result)
675
+ // This can be optimized to direct pointer usage when safe
676
+ // Process recursively through function calls
677
+ for (auto inst : module ->getGlobalInsts ())
678
+ {
679
+ auto func = as<IRFunc>(inst);
680
+ if (!func)
681
+ continue ;
682
+
683
+ changed |= eliminateLoadStorePairsInFunc (func, processedFuncs);
684
+ }
685
+
686
+ return changed;
687
+ }
688
+
515
689
bool shouldInstBeLiveIfParentIsLive (IRInst* inst, IRDeadCodeEliminationOptions options)
516
690
{
517
691
// The main source of confusion/complexity here is that
0 commit comments