@@ -99,20 +99,22 @@ void GCNRegPressure::inc(unsigned Reg,
99
99
bool GCNRegPressure::less (const MachineFunction &MF, const GCNRegPressure &O,
100
100
unsigned MaxOccupancy) const {
101
101
const GCNSubtarget &ST = MF.getSubtarget <GCNSubtarget>();
102
+ unsigned ArchVGPRThreshold = ST.getMaxNumVectorRegs (MF.getFunction ()).first ;
102
103
unsigned DynamicVGPRBlockSize =
103
104
MF.getInfo <SIMachineFunctionInfo>()->getDynamicVGPRBlockSize ();
104
105
105
106
const auto SGPROcc = std::min (MaxOccupancy,
106
107
ST.getOccupancyWithNumSGPRs (getSGPRNum ()));
107
108
const auto VGPROcc = std::min (
108
- MaxOccupancy, ST.getOccupancyWithNumVGPRs (getVGPRNum (ST.hasGFX90AInsts ()),
109
- DynamicVGPRBlockSize));
109
+ MaxOccupancy, ST.getOccupancyWithNumVGPRs (
110
+ getVGPRNum (ST.hasGFX90AInsts (), ArchVGPRThreshold),
111
+ DynamicVGPRBlockSize));
110
112
const auto OtherSGPROcc = std::min (MaxOccupancy,
111
113
ST.getOccupancyWithNumSGPRs (O.getSGPRNum ()));
112
- const auto OtherVGPROcc =
113
- std::min ( MaxOccupancy,
114
- ST. getOccupancyWithNumVGPRs ( O.getVGPRNum (ST.hasGFX90AInsts ()),
115
- DynamicVGPRBlockSize));
114
+ const auto OtherVGPROcc = std::min (
115
+ MaxOccupancy, ST. getOccupancyWithNumVGPRs (
116
+ O.getVGPRNum (ST.hasGFX90AInsts (), ArchVGPRThreshold ),
117
+ DynamicVGPRBlockSize));
116
118
117
119
const auto Occ = std::min (SGPROcc, VGPROcc);
118
120
const auto OtherOcc = std::min (OtherSGPROcc, OtherVGPROcc);
@@ -135,35 +137,39 @@ bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
135
137
unsigned OtherVGPRForSGPRSpills =
136
138
(OtherExcessSGPR + (WaveSize - 1 )) / WaveSize;
137
139
138
- unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs ();
139
-
140
140
// Unified excess pressure conditions, accounting for VGPRs used for SGPR
141
141
// spills
142
- unsigned ExcessVGPR =
143
- std::max ( static_cast <int >(getVGPRNum (ST.hasGFX90AInsts ()) +
144
- VGPRForSGPRSpills - MaxVGPRs),
145
- 0 );
146
- unsigned OtherExcessVGPR =
147
- std::max ( static_cast <int >(O.getVGPRNum (ST.hasGFX90AInsts ()) +
148
- OtherVGPRForSGPRSpills - MaxVGPRs),
149
- 0 );
142
+ unsigned ExcessVGPR = std::max (
143
+ static_cast <int >(getVGPRNum (ST.hasGFX90AInsts (), ArchVGPRThreshold ) +
144
+ VGPRForSGPRSpills - MaxVGPRs),
145
+ 0 );
146
+ unsigned OtherExcessVGPR = std::max (
147
+ static_cast <int >(O.getVGPRNum (ST.hasGFX90AInsts (), ArchVGPRThreshold ) +
148
+ OtherVGPRForSGPRSpills - MaxVGPRs),
149
+ 0 );
150
150
// Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR
151
151
// spills
152
- unsigned ExcessArchVGPR = std::max (
153
- static_cast <int >(getVGPRNum (false ) + VGPRForSGPRSpills - MaxArchVGPRs),
154
- 0 );
152
+ unsigned AddressableArchVGPRs = ST.getAddressableNumArchVGPRs ();
153
+ unsigned ExcessArchVGPR =
154
+ std::max (static_cast <int >(getVGPRNum (false , ArchVGPRThreshold) +
155
+ VGPRForSGPRSpills - AddressableArchVGPRs),
156
+ 0 );
155
157
unsigned OtherExcessArchVGPR =
156
- std::max (static_cast <int >(O.getVGPRNum (false ) + OtherVGPRForSGPRSpills -
157
- MaxArchVGPRs ),
158
+ std::max (static_cast <int >(O.getVGPRNum (false , ArchVGPRThreshold ) +
159
+ OtherVGPRForSGPRSpills - AddressableArchVGPRs ),
158
160
0 );
159
161
// AGPR excess pressure conditions
160
- unsigned ExcessAGPR = std::max (
161
- static_cast <int >(ST.hasGFX90AInsts () ? (getAGPRNum () - MaxArchVGPRs)
162
- : (getAGPRNum () - MaxVGPRs)),
163
- 0 );
162
+ unsigned ExcessAGPR =
163
+ std::max (static_cast <int >(
164
+ ST.hasGFX90AInsts ()
165
+ ? (getAGPRNum (ArchVGPRThreshold) - AddressableArchVGPRs)
166
+ : (getAGPRNum (ArchVGPRThreshold) - MaxVGPRs)),
167
+ 0 );
164
168
unsigned OtherExcessAGPR = std::max (
165
- static_cast <int >(ST.hasGFX90AInsts () ? (O.getAGPRNum () - MaxArchVGPRs)
166
- : (O.getAGPRNum () - MaxVGPRs)),
169
+ static_cast <int >(
170
+ ST.hasGFX90AInsts ()
171
+ ? (O.getAGPRNum (ArchVGPRThreshold) - AddressableArchVGPRs)
172
+ : (O.getAGPRNum (ArchVGPRThreshold) - MaxVGPRs)),
167
173
0 );
168
174
169
175
bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
@@ -184,14 +190,21 @@ bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
184
190
return VGPRDiff > 0 ;
185
191
if (SGPRDiff != 0 ) {
186
192
unsigned PureExcessVGPR =
187
- std::max (static_cast <int >(getVGPRNum (ST.hasGFX90AInsts ()) - MaxVGPRs),
193
+ std::max (static_cast <int >(
194
+ getVGPRNum (ST.hasGFX90AInsts (), ArchVGPRThreshold) -
195
+ MaxVGPRs),
188
196
0 ) +
189
- std::max (static_cast <int >(getVGPRNum (false ) - MaxArchVGPRs), 0 );
197
+ std::max (static_cast <int >(getVGPRNum (false , ArchVGPRThreshold) -
198
+ AddressableArchVGPRs),
199
+ 0 );
190
200
unsigned OtherPureExcessVGPR =
191
- std::max (
192
- static_cast <int >(O.getVGPRNum (ST.hasGFX90AInsts ()) - MaxVGPRs),
193
- 0 ) +
194
- std::max (static_cast <int >(O.getVGPRNum (false ) - MaxArchVGPRs), 0 );
201
+ std::max (static_cast <int >(
202
+ O.getVGPRNum (ST.hasGFX90AInsts (), ArchVGPRThreshold) -
203
+ MaxVGPRs),
204
+ 0 ) +
205
+ std::max (static_cast <int >(O.getVGPRNum (false , ArchVGPRThreshold) -
206
+ AddressableArchVGPRs),
207
+ 0 );
195
208
196
209
// If we have a special case where there is a tie in excess VGPR, but one
197
210
// of the pressures has VGPR usage from SGPR spills, prefer the pressure
@@ -221,38 +234,45 @@ bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
221
234
if (SW != OtherSW)
222
235
return SW < OtherSW;
223
236
} else {
224
- auto VW = getVGPRTuplesWeight ();
225
- auto OtherVW = O.getVGPRTuplesWeight ();
237
+ auto VW = getVGPRTuplesWeight (ArchVGPRThreshold );
238
+ auto OtherVW = O.getVGPRTuplesWeight (ArchVGPRThreshold );
226
239
if (VW != OtherVW)
227
240
return VW < OtherVW;
228
241
}
229
242
}
230
243
231
244
// Give final precedence to lower general RP.
232
- return SGPRImportant ? (getSGPRNum () < O.getSGPRNum ()):
233
- (getVGPRNum (ST.hasGFX90AInsts ()) <
234
- O.getVGPRNum (ST.hasGFX90AInsts ()));
245
+ return SGPRImportant ? (getSGPRNum () < O.getSGPRNum ())
246
+ : (getVGPRNum (ST.hasGFX90AInsts (), ArchVGPRThreshold ) <
247
+ O.getVGPRNum (ST.hasGFX90AInsts (), ArchVGPRThreshold ));
235
248
}
236
249
237
250
/// Build a Printable that writes a one-line summary of the register pressure
/// \p RP to a raw_ostream.
///
/// \param RP  Pressure to describe. It is captured by reference, so it must
///            outlive the returned Printable.
/// \param ST  Optional subtarget. When non-null, the occupancy implied by the
///            VGPR and SGPR counts is appended after each count.
/// \param DynamicVGPRBlockSize  Forwarded to getOccupancyWithNumVGPRs().
/// \param MF  Optional machine function. When both \p ST and \p MF are
///            non-null, the ArchVGPR threshold is taken from
///            getMaxNumVectorRegs() for MF's function and the overall
///            occupancy is printed; otherwise the threshold is left at the
///            maximum unsigned value and the overall occupancy is omitted.
Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST,
                      unsigned DynamicVGPRBlockSize,
                      const MachineFunction *MF) {
  // Without both a subtarget and a function there is no per-function limit to
  // query, so fall back to the largest representable threshold.
  unsigned ArchVGPRThreshold = std::numeric_limits<unsigned int>::max();
  if (ST && MF)
    ArchVGPRThreshold = ST->getMaxNumVectorRegs(MF->getFunction()).first;

  return Printable(
      [&RP, ST, DynamicVGPRBlockSize, ArchVGPRThreshold, MF](raw_ostream &OS) {
        OS << "VGPRs: " << RP.getArchVGPRNum(ArchVGPRThreshold) << ' '
           << "AGPRs: " << RP.getAGPRNum(ArchVGPRThreshold);
        if (ST)
          OS << "(O"
             << ST->getOccupancyWithNumVGPRs(
                    RP.getVGPRNum(ST->hasGFX90AInsts(), ArchVGPRThreshold),
                    DynamicVGPRBlockSize)
             << ')';
        OS << ", SGPRs: " << RP.getSGPRNum();
        if (ST)
          OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
        OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight(ArchVGPRThreshold)
           << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
        // getOccupancy() needs the machine function itself; MF is nullable, so
        // it must be checked before it is dereferenced (checking ST alone is
        // not sufficient).
        if (ST && MF)
          OS << " -> Occ: " << RP.getOccupancy(*MF);
        OS << '\n';
      });
}
257
277
258
278
static LaneBitmask getDefRegMask (const MachineOperand &MO,
@@ -398,8 +418,9 @@ void GCNRPTarget::setRegLimits(unsigned NumSGPRs, unsigned NumVGPRs,
398
418
const GCNSubtarget &ST = MF.getSubtarget <GCNSubtarget>();
399
419
unsigned DynamicVGPRBlockSize =
400
420
MF.getInfo <SIMachineFunctionInfo>()->getDynamicVGPRBlockSize ();
421
+ AddressableNumArchVGPRs = ST.getAddressableNumArchVGPRs ();
401
422
MaxSGPRs = std::min (ST.getAddressableNumSGPRs (), NumSGPRs);
402
- MaxVGPRs = std::min (ST. getAddressableNumArchVGPRs () , NumVGPRs);
423
+ MaxVGPRs = std::min (AddressableNumArchVGPRs , NumVGPRs);
403
424
MaxUnifiedVGPRs =
404
425
ST.hasGFX90AInsts ()
405
426
? std::min (ST.getAddressableNumVGPRs (DynamicVGPRBlockSize), NumVGPRs)
@@ -414,15 +435,21 @@ bool GCNRPTarget::isSaveBeneficial(Register Reg,
414
435
415
436
if (SRI->isSGPRClass (RC))
416
437
return RP.getSGPRNum () > MaxSGPRs;
417
- unsigned NumVGPRs =
418
- SRI->isAGPRClass (RC) ? RP.getAGPRNum () : RP.getArchVGPRNum ();
438
+
439
+ bool ShouldUseAGPR =
440
+ SRI->isAGPRClass (RC) ||
441
+ (SRI->isVectorSuperClass (RC) &&
442
+ RP.getArchVGPRNum (AddressableNumArchVGPRs) >= AddressableNumArchVGPRs);
443
+ unsigned NumVGPRs = ShouldUseAGPR
444
+ ? RP.getAGPRNum (AddressableNumArchVGPRs)
445
+ : RP.getArchVGPRNum (AddressableNumArchVGPRs);
419
446
return isVGPRBankSaveBeneficial (NumVGPRs);
420
447
}
421
448
422
449
bool GCNRPTarget::satisfied () const {
423
450
if (RP.getSGPRNum () > MaxSGPRs)
424
451
return false ;
425
- if (RP.getVGPRNum (false ) > MaxVGPRs &&
452
+ if (RP.getVGPRNum (false , AddressableNumArchVGPRs ) > MaxVGPRs &&
426
453
(!CombineVGPRSavings || !satisifiesVGPRBanksTarget ()))
427
454
return false ;
428
455
return satisfiesUnifiedTarget ();
@@ -876,10 +903,13 @@ bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
876
903
877
904
OS << " ---\n name: " << MF.getName () << " \n body: |\n " ;
878
905
879
- auto printRP = [](const GCNRegPressure &RP) {
880
- return Printable ([&RP](raw_ostream &OS) {
906
+ auto printRP = [&MF ](const GCNRegPressure &RP) {
907
+ return Printable ([&RP, &MF ](raw_ostream &OS) {
881
908
OS << format (PFX " %-5d" , RP.getSGPRNum ())
882
- << format (" %-5d" , RP.getVGPRNum (false ));
909
+ << format (" %-5d" , RP.getVGPRNum (false , MF.getSubtarget <GCNSubtarget>()
910
+ .getMaxNumVectorRegs (
911
+ MF.getFunction ())
912
+ .first ));
883
913
});
884
914
};
885
915
0 commit comments