[NVPTX] lower VECREDUCE min/max to 3-input on sm_100+ #136253

Open · wants to merge 1 commit into main
139 changes: 139 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -850,6 +850,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
if (STI.allowFP16Math() || STI.hasBF16Math())
setTargetDAGCombine(ISD::SETCC);

// Vector reduction operations. These may be turned into shuffle or tree
// reductions depending on what instructions are available for each type.
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
MVT EltVT = VT.getVectorElementType();
if (EltVT == MVT::f32 || EltVT == MVT::f64) {
setOperationAction({ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
VT, Custom);
}
}

// Promote fp16 arithmetic if fp16 hardware isn't available or the
// user passed --nvptx-no-fp16-math. The flag is useful because,
// although sm_53+ GPUs have some sort of FP16 support in
@@ -1093,6 +1104,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::BFI)
MAKE_CASE(NVPTXISD::PRMT)
MAKE_CASE(NVPTXISD::FCOPYSIGN)
MAKE_CASE(NVPTXISD::FMAXNUM3)
MAKE_CASE(NVPTXISD::FMINNUM3)
MAKE_CASE(NVPTXISD::FMAXIMUM3)
MAKE_CASE(NVPTXISD::FMINIMUM3)
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
MAKE_CASE(NVPTXISD::STACKRESTORE)
MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -1900,6 +1915,125 @@ static SDValue getPRMT(SDValue A, SDValue B, uint64_t Selector, SDLoc DL,
return getPRMT(A, B, DAG.getConstant(Selector, DL, MVT::i32), DL, DAG, Mode);
}

/// Reduces the elements using the scalar operations provided. The operations
/// must be sorted in descending order by the number of inputs they take. The
/// flags on the original reduction operation are propagated to each scalar
/// operation. Nearby elements are grouped into a tree reduction, unlike the
/// shuffle-based reduction used in ExpandReductions and SelectionDAG.
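///
/// For example (illustrative), reducing seven elements with
/// Ops = {(FMAXNUM3, 3), (FMAXNUM, 2)} builds:
///   level 0: e0  e1  e2  e3  e4  e5  e6
///   level 1: max3(e0,e1,e2)  max3(e3,e4,e5)  e6
///   level 2: max3(max3(e0,e1,e2), max3(e3,e4,e5), e6)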
static SDValue buildTreeReduction(
const SmallVector<SDValue> &Elements, EVT EltTy,
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
// Build the reduction tree at each level, starting with all the elements.
SmallVector<SDValue> Level = Elements;

unsigned OpIdx = 0;
while (Level.size() > 1) {
// Try to reduce this level using the current operator.
const auto [Op, NumInputs] = Ops[OpIdx];

// Build the next level by partially reducing all elements.
SmallVector<SDValue> ReducedLevel;
unsigned I = 0, E = Level.size();
for (; I + NumInputs <= E; I += NumInputs) {
// Reduce elements in groups of [NumInputs], as much as possible.
ReducedLevel.push_back(DAG.getNode(
Op, DL, EltTy, ArrayRef<SDValue>(Level).slice(I, NumInputs), Flags));
}

if (I < E) {
// Handle leftover elements.

if (ReducedLevel.empty()) {
// We didn't reduce anything at this level. We need to pick a smaller
// operator.
++OpIdx;
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
continue;
}

// We reduced some elements but not all of them; the operator's number of
// inputs doesn't evenly divide this level's size. Carry the remaining
// elements over to the next level.
for (; I < E; ++I)
ReducedLevel.push_back(Level[I]);
}

// Process the next level.
Level = ReducedLevel;
}

return *Level.begin();
}

/// Get the 2-input scalar opcode for a reduction.
static ISD::NodeType getScalarOpcodeForReduction(unsigned ReductionOpcode) {
switch (ReductionOpcode) {
case ISD::VECREDUCE_FMAX:
return ISD::FMAXNUM;
case ISD::VECREDUCE_FMIN:
return ISD::FMINNUM;
case ISD::VECREDUCE_FMAXIMUM:
return ISD::FMAXIMUM;
case ISD::VECREDUCE_FMINIMUM:
return ISD::FMINIMUM;
default:
llvm_unreachable("unhandled reduction opcode");
}
}

/// Get 3-input scalar reduction opcode
static std::optional<NVPTXISD::NodeType>
getScalar3OpcodeForReduction(unsigned ReductionOpcode) {
switch (ReductionOpcode) {
case ISD::VECREDUCE_FMAX:
return NVPTXISD::FMAXNUM3;
case ISD::VECREDUCE_FMIN:
return NVPTXISD::FMINNUM3;
case ISD::VECREDUCE_FMAXIMUM:
return NVPTXISD::FMAXIMUM3;
case ISD::VECREDUCE_FMINIMUM:
return NVPTXISD::FMINIMUM3;
default:
return std::nullopt;
}
}

/// Lower reductions to either a sequence of operations or a tree, if
/// reassociation is allowed. This method uses larger operations like
/// max3/min3 when the target supports them.
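///
/// For example (sketch), a 4-element f32 max reduction on sm_100+/PTX 8.8 is
/// expected to become max(max3(e0, e1, e2), e3); without 3-input support the
/// 2-input tree max(max(e0, e1), max(e2, e3)) is built instead.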
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
const SDNodeFlags Flags = Op->getFlags();
SDValue Vector = Op.getOperand(0);

const unsigned Opcode = Op->getOpcode();
const EVT EltTy = Vector.getValueType().getVectorElementType();

// Whether we can use 3-input min/max when expanding the reduction.
const bool CanUseMinMax3 =
EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
STI.getPTXVersion() >= 88 &&
(Opcode == ISD::VECREDUCE_FMAX || Opcode == ISD::VECREDUCE_FMIN ||
Opcode == ISD::VECREDUCE_FMAXIMUM || Opcode == ISD::VECREDUCE_FMINIMUM);

// A list of SDNode opcodes with equivalent semantics, sorted descending by
// number of inputs they take.
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;

if (auto Opcode3Elem = getScalar3OpcodeForReduction(Opcode);
CanUseMinMax3 && Opcode3Elem)
ScalarOps.push_back({*Opcode3Elem, 3});
ScalarOps.push_back({getScalarOpcodeForReduction(Opcode), 2});

// Extract the vector elements and reduce them with the scalar operations
// chosen above.
SmallVector<SDValue> Elements;
DAG.ExtractVectorElements(Vector, Elements);

return buildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
}

SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
// Handle bitcasting from v2i8 without hitting the default promotion
// strategy which goes through stack memory.
@@ -2779,6 +2913,11 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerVECTOR_SHUFFLE(Op, DAG);
case ISD::CONCAT_VECTORS:
return LowerCONCAT_VECTORS(Op, DAG);
case ISD::VECREDUCE_FMAX:
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMINIMUM:
return LowerVECREDUCE(Op, DAG);
case ISD::STORE:
return LowerSTORE(Op, DAG);
case ISD::LOAD:
6 changes: 6 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -64,6 +64,11 @@ enum NodeType : unsigned {
UNPACK_VECTOR,

FCOPYSIGN,
FMAXNUM3,
FMINNUM3,
FMAXIMUM3,
FMINIMUM3,

DYNAMIC_STACKALLOC,
STACKRESTORE,
STACKSAVE,
@@ -286,6 +291,7 @@ class NVPTXTargetLowering : public TargetLowering {

SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
44 changes: 44 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -356,6 +356,36 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
}

// Template for 3-input minimum/maximum instructions
// (sm_100+/PTX 8.8 and f32 only)
//
// Also defines ftz (flush subnormal inputs and results to sign-preserving
// zero) variants for fp32 functions.
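//
// For instance, the rrr variant of the NaN-propagating max is expected to
// assemble to something like "max.NaN.f32 %r1, %r2, %r3, %r4;" (the exact
// modifier string is formed from OpcStr, $ftz, and nan_str below).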
multiclass FMINIMUMMAXIMUM3<string OpcStr, bit NaN, SDNode OpNode> {
defvar nan_str = !if(NaN, ".NaN", "");
def f32rrr :
BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b, B32:$c),
(ins FTZFlag:$ftz),
OpcStr # "$ftz" # nan_str # ".f32",
[(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
Requires<[hasPTX<88>, hasSM<100>]>;
def f32rri :
BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b, f32imm:$c),
(ins FTZFlag:$ftz),
OpcStr # "$ftz" # nan_str # ".f32",
[(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
Requires<[hasPTX<88>, hasSM<100>]>;
def f32rii :
BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, f32imm:$b, f32imm:$c),
(ins FTZFlag:$ftz),
OpcStr # "$ftz" # nan_str # ".f32",
[(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
Requires<[hasPTX<88>, hasSM<100>]>;
}

// Template for instructions which take three FP args. The
// instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
//
@@ -971,6 +1001,20 @@ defm MAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
defm MIN_NAN : FMINIMUMMAXIMUM<"min", /* NaN */ true, fminimum>;
defm MAX_NAN : FMINIMUMMAXIMUM<"max", /* NaN */ true, fmaximum>;

def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
[SDNPCommutative]>;
def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
[SDNPCommutative]>;
def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
[SDNPCommutative]>;
def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
[SDNPCommutative]>;

defm FMIN3 : FMINIMUMMAXIMUM3<"min", /* NaN */ false, nvptx_fminnum3>;
defm FMAX3 : FMINIMUMMAXIMUM3<"max", /* NaN */ false, nvptx_fmaxnum3>;
defm FMINNAN3 : FMINIMUMMAXIMUM3<"min", /* NaN */ true, nvptx_fminimum3>;
defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max", /* NaN */ true, nvptx_fmaximum3>;

defm FABS : F2<"abs", fabs>;
defm FNEG : F2<"neg", fneg>;
defm FABS_H: F2_Support_Half<"abs", fabs>;
7 changes: 7 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -87,6 +87,13 @@ class NVPTXTTIImpl final : public BasicTTIImplBase<NVPTXTTIImpl> {
}
unsigned getMinVectorRegisterBitWidth() const override { return 32; }

bool shouldExpandReduction(const IntrinsicInst *II) const override {
// Turn off the ExpandReductions pass for NVPTX, which lacks advanced
// swizzling operations. Our SelectionDAG lowering can expand these
// reductions with fewer movs.
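// For example, for a <4 x float> llvm.vector.reduce.fmax the generic pass
// would expand the reduction via vector shuffles, whereas LowerVECREDUCE
// extracts the lanes and emits a max3/max tree directly.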
return false;
}

// We don't want to prevent inlining because of target-cpu and -features
// attributes that were added to newer versions of LLVM/Clang: There are
// no incompatible functions in PTX, ptxas will throw errors in such cases.