@@ -1098,29 +1098,80 @@ bool RewriteScheduleStage::initGCNSchedStage() {
1098
1098
}
1099
1099
}
1100
1100
1101
- bool ShouldRewrite = false ;
1101
+ unsigned ArchVGPRThreshold =
1102
+ ST.getMaxNumVectorRegs (DAG.MF .getFunction ()).first ;
1103
+
1104
+ int64_t Cost = 0 ;
1105
+ MBFI.calculate (MF, MBPI, *DAG.MLI );
1102
1106
for (unsigned RegionIdx = 0 ; RegionIdx < DAG.Regions .size (); RegionIdx++) {
1103
1107
if (!DAG.RegionsWithExcessArchVGPR [RegionIdx])
1104
1108
continue ;
1105
1109
1110
+ unsigned MaxCombinedVGPRs = ST.getMaxNumVGPRs (MF);
1111
+
1112
+ auto PressureBefore = DAG.Pressure [RegionIdx];
1113
+ unsigned UnifiedPressureBefore =
1114
+ PressureBefore.getVGPRNum (true , ArchVGPRThreshold);
1115
+ unsigned ArchPressureBefore =
1116
+ PressureBefore.getArchVGPRNum (ArchVGPRThreshold);
1117
+ unsigned AGPRPressureBefore = PressureBefore.getAGPRNum (ArchVGPRThreshold);
1118
+ unsigned UnifiedSpillBefore =
1119
+ UnifiedPressureBefore > MaxCombinedVGPRs
1120
+ ? (UnifiedPressureBefore - MaxCombinedVGPRs)
1121
+ : 0 ;
1122
+ unsigned ArchSpillBefore =
1123
+ ArchPressureBefore > ST.getAddressableNumArchVGPRs ()
1124
+ ? (ArchPressureBefore - ST.getAddressableNumArchVGPRs ())
1125
+ : 0 ;
1126
+ unsigned AGPRSpillBefore =
1127
+ AGPRPressureBefore > ST.getAddressableNumArchVGPRs ()
1128
+ ? (AGPRPressureBefore - ST.getAddressableNumArchVGPRs ())
1129
+ : 0 ;
1130
+
1131
+ unsigned SpillCostBefore =
1132
+ std::max (UnifiedSpillBefore, (ArchSpillBefore + AGPRSpillBefore));
1133
+
1134
+
1106
1135
// For the cases we care about (i.e. ArchVGPR usage is greater than the
1107
1136
// addressable limit), rewriting alone should bring pressure to a manageable
1108
1137
// level. If we find any such region, then the rewrite is potentially
1109
1138
// beneficial.
1110
1139
auto PressureAfter = DAG.getRealRegPressure (RegionIdx);
1111
- unsigned MaxCombinedVGPRs = ST.getMaxNumVGPRs (MF);
1112
- if (PressureAfter.getArchVGPRNum () <= ST.getAddressableNumArchVGPRs () &&
1113
- PressureAfter.getVGPRNum (true ) <= MaxCombinedVGPRs) {
1114
- ShouldRewrite = true ;
1115
- break ;
1116
- }
1140
+ unsigned UnifiedPressureAfter =
1141
+ PressureAfter.getVGPRNum (true , ArchVGPRThreshold);
1142
+ unsigned ArchPressureAfter =
1143
+ PressureAfter.getArchVGPRNum (ArchVGPRThreshold);
1144
+ unsigned AGPRPressureAfter = PressureAfter.getAGPRNum (ArchVGPRThreshold);
1145
+ unsigned UnifiedSpillAfter = UnifiedPressureAfter > MaxCombinedVGPRs
1146
+ ? (UnifiedPressureAfter - MaxCombinedVGPRs)
1147
+ : 0 ;
1148
+ unsigned ArchSpillAfter =
1149
+ ArchPressureAfter > ST.getAddressableNumArchVGPRs ()
1150
+ ? (ArchPressureAfter - ST.getAddressableNumArchVGPRs ())
1151
+ : 0 ;
1152
+ unsigned AGPRSpillAfter =
1153
+ AGPRPressureAfter > ST.getAddressableNumArchVGPRs ()
1154
+ ? (AGPRPressureAfter - ST.getAddressableNumArchVGPRs ())
1155
+ : 0 ;
1156
+
1157
+ unsigned SpillCostAfter =
1158
+ std::max (UnifiedSpillAfter, (ArchSpillAfter + AGPRSpillAfter));
1159
+
1160
+ uint64_t EntryFreq = MBFI.getEntryFreq ().getFrequency ();
1161
+ uint64_t BlockFreq =
1162
+ EntryFreq ? MBFI.getBlockFreq (DAG.Regions [RegionIdx].first ->getParent ())
1163
+ .getFrequency () / EntryFreq
1164
+ : 1 ;
1165
+
1166
+ // Assumes perfect spilling -- giving edge to VGPR form.
1167
+ Cost += ((int )SpillCostAfter - (int )SpillCostBefore) * (int )BlockFreq;
1117
1168
}
1118
1169
1119
1170
// Account for the cross RC copies the rewrite would insert, weighted by the
1120
1171
// frequency of their use blocks; bail once the cost is no longer negative.
1172
+ bool ShouldRewrite = Cost < 0 ;
1121
1173
if (ShouldRewrite) {
1122
- CI.clear ();
1123
- CI.compute (MF);
1174
+ uint64_t EntryFreq = MBFI.getEntryFreq ().getFrequency ();
1124
1175
1125
1176
for (auto *DefMI : CrossRCUseCopies) {
1126
1177
auto DefReg = DefMI->getOperand (0 ).getReg ();
@@ -1137,11 +1188,16 @@ bool RewriteScheduleStage::initGCNSchedStage() {
1137
1188
if (!RequiredRC || SRI->hasAGPRs (RequiredRC))
1138
1189
continue ;
1139
1190
1140
- unsigned DefDepth = CI.getCycleDepth (DefMI->getParent ());
1141
- if (DefDepth && CI.getCycleDepth (UseMI.getParent ()) >= DefDepth) {
1142
- ShouldRewrite = false ;
1191
+ uint64_t UseFreq =
1192
+ EntryFreq ? MBFI.getBlockFreq (UseMI.getParent ()).getFrequency () /
1193
+ EntryFreq
1194
+ : 1 ;
1195
+
1196
+ // Assumes no copy-reuse, giving edge to VGPR form.
1197
+ Cost += UseFreq;
1198
+ ShouldRewrite = Cost < 0 ;
1199
+ if (!ShouldRewrite)
1143
1200
break ;
1144
- }
1145
1201
}
1146
1202
if (!ShouldRewrite)
1147
1203
break ;
@@ -1596,7 +1652,8 @@ void GCNSchedStage::checkScheduling() {
1596
1652
DAG.RegionsWithExcessRP [RegionIdx] = true ;
1597
1653
}
1598
1654
1599
- if (PressureAfter.getArchVGPRNum () > ST.getAddressableNumArchVGPRs ())
1655
+ if (PressureAfter.getArchVGPRNum (ArchVGPRThreshold) >
1656
+ ST.getAddressableNumArchVGPRs ())
1600
1657
DAG.RegionsWithExcessArchVGPR [RegionIdx] = true ;
1601
1658
1602
1659
// Revert if this region's schedule would cause a drop in occupancy or
0 commit comments