Handle VECREDUCE intrinsics in NVPTX backend #136253


Open

Prince781 wants to merge 1 commit into main from dev/pferro/nvptx-vector-reduce
Conversation

Prince781
Contributor

@Prince781 Prince781 commented Apr 18, 2025

Improves handling of VECREDUCE intrinsics in NVPTX. Also adds support for min3/max3 from PTX ISA 8.8.

@llvmbot
Member

llvmbot commented Apr 18, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

Changes

Lower VECREDUCE intrinsics to tree reductions when reassociations are allowed, and add support for min3/max3 from PTX ISA 8.8 (to be released next week).
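As a hand-written illustration (not taken from the patch), a reduction like the one below carries no evaluation-order constraint once the reassoc flag is present, so it can be lowered as a balanced tree of 2-input max operations, and on sm_100+ with PTX 8.8 as a shallower tree of 3-input max3 operations:

; hypothetical example: with reassoc, lowerable as a tree of max ops,
; or of max3 ops on sm_100+/PTX 8.8
define float @reduce_fmax_example(<8 x float> %v) {
  %r = call reassoc float @llvm.vector.reduce.fmax.v8f32(<8 x float> %v)
  ret float %r
}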


Patch is 94.44 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136253.diff

6 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTX.td (+8-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+230)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+6)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+54)
  • (modified) llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h (+2)
  • (added) llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll (+1908)
diff --git a/llvm/lib/Target/NVPTX/NVPTX.td b/llvm/lib/Target/NVPTX/NVPTX.td
index 5467ae011a208..d4dc278cfa648 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.td
+++ b/llvm/lib/Target/NVPTX/NVPTX.td
@@ -36,17 +36,19 @@ class FeaturePTX<int version>:
 
 foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53,
               60, 61, 62, 70, 72, 75, 80, 86, 87,
-              89, 90, 100, 101, 120] in
+              89, 90, 100, 101, 103, 120, 121] in
   def SM#sm: FeatureSM<""#sm, !mul(sm, 10)>;
 
 def SM90a: FeatureSM<"90a", 901>;
 def SM100a: FeatureSM<"100a", 1001>;
 def SM101a: FeatureSM<"101a", 1011>;
+def SM103a: FeatureSM<"103a", 1031>;
 def SM120a: FeatureSM<"120a", 1201>;
+def SM121a: FeatureSM<"121a", 1211>;
 
 foreach version = [32, 40, 41, 42, 43, 50, 60, 61, 62, 63, 64, 65,
                    70, 71, 72, 73, 74, 75, 76, 77, 78,
-                   80, 81, 82, 83, 84, 85, 86, 87] in
+                   80, 81, 82, 83, 84, 85, 86, 87, 88] in
   def PTX#version: FeaturePTX<version>;
 
 //===----------------------------------------------------------------------===//
@@ -81,8 +83,12 @@ def : Proc<"sm_100", [SM100, PTX86]>;
 def : Proc<"sm_100a", [SM100a, PTX86]>;
 def : Proc<"sm_101", [SM101, PTX86]>;
 def : Proc<"sm_101a", [SM101a, PTX86]>;
+def : Proc<"sm_103", [SM103, PTX88]>;
+def : Proc<"sm_103a", [SM103a, PTX88]>;
 def : Proc<"sm_120", [SM120, PTX87]>;
 def : Proc<"sm_120a", [SM120a, PTX87]>;
+def : Proc<"sm_121", [SM121, PTX88]>;
+def : Proc<"sm_121a", [SM121a, PTX88]>;
 
 def NVPTXInstrInfo : InstrInfo {
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9bde2a976e164..3a14aa47e5065 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -831,6 +831,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These are transformed into a tree evaluation
+  // of nodes which may initially be illegal.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1082,6 +1102,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2136,6 +2160,194 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method differs from iterative splitting in DAGTypeLegalizer by
+/// progressively grouping elements bottom-up.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // now build the computation graph in place at each level
+  SmallVector<SDValue> Level = Elements;
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // partially reduce all elements in level
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      if (ReducedLevel.empty()) {
+        // The current operator requires more inputs than there are operands at
+        // this level. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // Otherwise, we just have a remainder, which we push to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+  bool IsReassociatable;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsReassociatable = false;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsReassociatable = true;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsReassociatable = true;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+
+  // scalarize vector
+  SmallVector<SDValue> Elements(NumElts);
+  for (unsigned I = 0, E = NumElts; I != E; ++I) {
+    Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
+                              DAG.getConstant(I, DL, MVT::i64));
+  }
+
+  // Lower to tree reduction.
+  if (IsReassociatable || Flags.hasAllowReassociation()) {
+    // we don't expect an accumulator for reassociatable vector reduction ops
+    assert(!Accumulator && "unexpected accumulator");
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+  }
+
+  // Lower to sequential reduction.
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2879,6 +3091,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index dd90746f6d9d6..0880f3e8edfe8 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -296,6 +301,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 16b489afddf5c..721b578ba72ee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -368,6 +368,46 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
 
+// 3-input min/max (sm_100+) for f32 only
+multiclass FMINIMUMMAXIMUM3<string OpcStr, SDNode OpNode> {
+   def f32rrr_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rri_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rii_ftz :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+               !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+               Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+   def f32rrr :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+   def f32rri :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b, fpimm:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+   def f32rii :
+     NVPTXInst<(outs Float32Regs:$dst),
+               (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+               !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+               [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+               Requires<[hasPTX<88>, hasSM<100>]>;
+}
+
 // Template for instructions which take three FP args.  The
 // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
 //
@@ -1101,6 +1141,20 @@ defm FMAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
 defm FMINNAN : FMINIMUMMAXIMUM<"min.NaN", /* NaN */ true, fminimum>;
 defm FMAXNAN : FMINIMUMMAXIMUM<"max.NaN", /* NaN */ true, fmaximum>;
 
+def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+
+defm FMIN3 : FMINIMUMMAXIMUM3<"min", nvptx_fminnum3>;
+defm FMAX3 : FMINIMUMMAXIMUM3<"max", nvptx_fmaxnum3>;
+defm FMINNAN3 : FMINIMUMMAXIMUM3<"min.NaN", nvptx_fminimum3>;
+defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max.NaN", nvptx_fmaximum3>;
+
 defm FABS  : F2<"abs", fabs>;
 defm FNEG  : F2<"neg", fneg>;
 defm FABS_H: F2_Support_Half<"abs", fabs>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index 9e77f628da7a7..2cc81a064152f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -83,6 +83,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   }
   unsigned getMinVectorRegisterBitWidth() const { return 32; }
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const { return false; }
+
   // We don't want to prevent inlining because of target-cpu and -features
   // attributes that were added to newer versions of LLVM/Clang: There are
   // no incompatible functions in PTX, ptxas will throw errors in such cases.
diff --git a/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll b/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
new file mode 100644
index 0000000000000..a9101ba3ca651
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll
@@ -0,0 +1,1908 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --extra_scrub --version 5
+; RUN: llc < %s -mcpu=sm_80 -mattr=+ptx70 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | FileCheck -check-prefixes CHECK,CHECK-SM80 %s
+; RUN: %if ptxas-12.9 %{ llc < %s -mcpu=sm_80 -mattr=+ptx70 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | %ptxas-verify -arch=sm_80 %}
+; RUN: llc < %s -mcpu=sm_100 -mattr=+ptx88 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | FileCheck -check-prefixes CHECK,CHECK-SM100 %s
+; RUN: %if ptxas-12.9 %{ llc < %s -mcpu=sm_100 -mattr=+ptx88 -O0 \
+; RUN:      -disable-post-ra -verify-machineinstrs \
+; RUN: | %ptxas-verify -arch=sm_100 %}
+target triple = "nvptx64-nvidia-cuda"
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+; Check straight line reduction.
+define half @reduce_fadd_half(<8 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<18>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_param_0];
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
+; CHECK-NEXT:    mov.b16 %rs3, 0x0000;
+; CHECK-NEXT:    add.rn.f16 %rs4, %rs1, %rs3;
+; CHECK-NEXT:    add.rn.f16 %rs5, %rs4, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs6, %rs7}, %r2;
+; CHECK-NEXT:    add.rn.f16 %rs8, %rs5, %rs6;
+; CHECK-NEXT:    add.rn.f16 %rs9, %rs8, %rs7;
+; CHECK-NEXT:    mov.b32 {%rs10, %rs11}, %r3;
+; CHECK-NEXT:    add.rn.f16 %rs12, %rs9, %rs10;
+; CHECK-NEXT:    add.rn.f16 %rs13, %rs12, %rs11;
+; CHECK-NEXT:    mov.b32 {%rs14, %rs15}, %r4;
+; CHECK-NEXT:    add.rn.f16 %rs16, %rs13, %rs14;
+; CHECK-NEXT:    add.rn.f16 %rs17, %rs16, %rs15;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs17;
+; CHECK-NEXT:    ret;
+  %res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
+  ret half %res
+}
+
+; Check tree reduction.
+define half @reduce_fadd_half_reassoc(<8 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half_reassoc(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<18>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_reassoc_param_0];
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
+; CHECK-NEXT:    add.rn.f16 %rs3, %rs1, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs4, %rs5}, %r3;
+; CHECK-NEXT:    add.rn.f16 %rs6, %rs4, %rs5;
+; CHECK-NEXT:    add.rn.f16 %rs7, %rs6, %rs3;
+; CHECK-NEXT:    mov.b32 {%rs8, %rs9}, %r2;
+; CHECK-NEXT:    add.rn.f16 %rs10, %rs8, %rs9;
+; CHECK-NEXT:    mov.b32 {%rs11, %rs12}, %r1;
+; CHECK-NEXT:    add.rn.f16 %rs13, %rs11, %rs12;
+; CHECK-NEXT:    add.rn.f16 %rs14, %rs13, %rs10;
+; CHECK-NEXT:    add.rn.f16 %rs15, %rs14, %rs7;
+; CHECK-NEXT:    mov.b16 %rs16, 0x0000;
+; CHECK-NEXT:    add.rn.f16 %rs17, %rs15, %rs16;
+; CHECK-NEXT:    st.param.b16 [func_retval0], %rs17;
+; CHECK-NEXT:    ret;
+  %res = call reassoc half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
+  ret half %res
+}
+
+; Check tree reduction with non-power of 2 size.
+define half @reduce_fadd_half_reassoc_nonpow2(<7 x half> %in) {
+; CHECK-LABEL: reduce_fadd_half_reassoc_nonpow2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<16>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [reduce_fadd_half_reassoc_nonpow2_param_0+8];
+; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r1;
+; CHECK-NEXT:    ld.param.b16 %rs7, [reduce_fadd_half_reassoc_nonpow2_param_0+12];
+; CHECK-NEXT:    ld.param....
[truncated]

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch from 1f1348c to 4e481f8 on April 18, 2025 04:42
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 2 times, most recently from 1ea7992 to c895233 on April 23, 2025 04:43
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch from c895233 to b13ed82 on May 22, 2025 11:10
@Prince781
Contributor Author

After some thought, it seems all three patterns for lowering a reduction have tradeoffs, and each is preferable in some circumstance. I'd like to propose keeping the default lowering (shuffle reduction) but adding a way for clients to override it with metadata. I'm thinking of something like !reduce.tree on the reduction call. Before I proceed, I'd like to know whether this is a plausible idea and whether I should propose it in an RFC. Thanks!
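For concreteness, here is a hand-written sketch of the three shapes for a small fadd reduction (illustrative only; function names are hypothetical):

; sequential: preserves FP evaluation order, longest dependency chain
define float @seq(<4 x float> %v) {
  %e0 = extractelement <4 x float> %v, i64 0
  %e1 = extractelement <4 x float> %v, i64 1
  %e2 = extractelement <4 x float> %v, i64 2
  %e3 = extractelement <4 x float> %v, i64 3
  %s0 = fadd float %e0, %e1
  %s1 = fadd float %s0, %e2
  %s2 = fadd float %s1, %e3
  ret float %s2
}

; tree: log-depth, requires reassociation
define float @tree(<4 x float> %v) {
  %e0 = extractelement <4 x float> %v, i64 0
  %e1 = extractelement <4 x float> %v, i64 1
  %e2 = extractelement <4 x float> %v, i64 2
  %e3 = extractelement <4 x float> %v, i64 3
  %a = fadd reassoc float %e0, %e1
  %b = fadd reassoc float %e2, %e3
  %r = fadd reassoc float %a, %b
  ret float %r
}

; shuffle: folds the top half onto the bottom half each step,
; so the wide steps can stay packed
define float @shuf(<4 x float> %v) {
  %h = shufflevector <4 x float> %v, <4 x float> poison, <4 x i32> <i32 2, i32 3, i32 poison, i32 poison>
  %p = fadd reassoc <4 x float> %v, %h
  %h2 = shufflevector <4 x float> %p, <4 x float> poison, <4 x i32> <i32 1, i32 poison, i32 poison, i32 poison>
  %p2 = fadd reassoc <4 x float> %p, %h2
  %r = extractelement <4 x float> %p2, i64 0
  ret float %r
}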

@Prince781
Contributor Author

Prince781 commented Jun 11, 2025

Proposed use:

; default (shuffle reduction for f16x2 and f32x2)
define float @reduce_fadd_reassoc(<16 x float> %in) {
  %res = call reassoc float @llvm.vector.reduce.fadd(float 0.0, <16 x float> %in)
  ret float %res
}

; default (tree reduction on SM100 for fmin3; shuffle reduction otherwise)
define float @reduce_fmin_reassoc(<16 x float> %in) {
  %res = call reassoc float @llvm.vector.reduce.fmin(<16 x float> %in)
  ret float %res
}

; force tree for reassoc float
define float @reduce_fadd_reassoc_tree(<16 x float> %in) {
  %res = call reassoc float @llvm.vector.reduce.fadd(float 0.0, <16 x float> %in) !reduce.tree
  ret float %res
}

; force shuffle
define float @reduce_fmin_reassoc_shuffle(<16 x float> %in) {
  %res = call reassoc float @llvm.vector.reduce.fmin(<16 x float> %in) !reduce.shuffle
  ret float %res
}

; force sequential
define i32 @reduce_umin(<16 x i32> %in) {
  %res = call i32 @llvm.vector.reduce.umin(<16 x i32> %in) !reduce.sequential
  ret i32 %res
}

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 2 times, most recently from 24925dc to a22c141 on June 19, 2025 04:49

github-actions bot commented Jun 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Prince781
Contributor Author

Actually, I'm going to move the proposal of metadata / attributes affecting lowering decisions to another discussion in a future PR or RFC.

I've just updated the PR. We'll use the default shuffle reduction in SelectionDAG when packed ops are available on the target for the element type (e.g. v2f16, v2bf16, v2i16). Otherwise, we'll use tree or sequential reductions, depending on the fast-math options or the reassoc flag, and on whether special operations like fmin3/fmax3 are available.

This allows us to keep using packed operations like max.f16x2 (which use fewer registers) while switching to tree reduction in every other case where we can reassociate operands. Since we're deciding this based on the target's support for the element type with the operation, I think it makes sense to handle these intrinsics at the SelectionDAG level rather than at the IR level inside shouldExpandReduction().

I also noticed that because we now fall back to the shuffle reduction generated by SelectionDAG instead of ExpandReductions, the codegen is cleaner and the last packed operation is scalarized. So this PR may supersede #143943
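To make the packed-op case concrete, here is a hand-written sketch (not output from this patch) of the shuffle shape for <4 x half>, where the wide step can map to a single packed max.f16x2 and only the final step needs to be scalarized:

define half @reduce_fmax_v4f16_sketch(<4 x half> %v) {
  %hi = shufflevector <4 x half> %v, <4 x half> poison, <4 x i32> <i32 2, i32 3, i32 poison, i32 poison>
  %m1 = call <4 x half> @llvm.maxnum.v4f16(<4 x half> %v, <4 x half> %hi)   ; packed: max.f16x2
  %hi2 = shufflevector <4 x half> %m1, <4 x half> poison, <4 x i32> <i32 1, i32 poison, i32 poison, i32 poison>
  %m2 = call <4 x half> @llvm.maxnum.v4f16(<4 x half> %m1, <4 x half> %hi2) ; only lane 0 is used, so this step can be scalarized
  %r = extractelement <4 x half> %m2, i32 0
  ret half %r
}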

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 6 times, most recently from ffd4140 to 5fbe8bc on June 25, 2025 00:42
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 2 times, most recently from d08903f to 7a4de91 on July 1, 2025 18:06
@Prince781 Prince781 requested review from AlexMaclean and Artem-B July 8, 2025 21:23
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 2 times, most recently from 72554b0 to 77a5de0 on July 21, 2025 08:09
@Prince781 Prince781 requested a review from AlexMaclean July 21, 2025 08:17
Comment on lines 90 to 95
bool shouldExpandReduction(const IntrinsicInst *II) const override {
  return false;
}
Contributor Author


FYI: this change is what seems to improve the codegen for non-f32 cases by letting SelectionDAG handle these intrinsics.

Member


Do we need to explicitly allow expansion on the nodes we can't lower to the PTX ops?
What's the default for non-FP vector reduce ops?
Perhaps it would make more sense to explicitly mark FP vector reductions as legal or custom, so it's clear that we're doing it intentionally.

Contributor Author

@Prince781 Prince781 Jul 29, 2025


The default was to lower all reductions to a shuffle reduction in LLVM IR like so:

define i32 @reduce(<4 x i32> %v) {
  %v.half = shufflevector <4 x i32> %v, <4 x i32> poison, <4 x i32> <i32 2, i32 3, i32 undef, i32 undef> ; = <%v[2], %v[3], undef, undef>
  %res.1 = add <4 x i32> %v, %v.half ; = <%v[0]+%v[2], %v[1]+%v[3], undef, undef>
  %res.1.half = shufflevector <4 x i32> %res.1, <4 x i32> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> ; = <%v[1]+%v[3], undef, undef, undef>
  %res.2 = add <4 x i32> %res.1, %res.1.half ; = <%v[0]+%v[2]+%v[1]+%v[3], undef, undef, undef>
  %result = extractelement <4 x i32> %res.2, i32 0 ; = %v[0]+%v[2]+%v[1]+%v[3]
  ret i32 %result
}

This is the IR produced by the ExpandReductions pass. Later on, SelectionDAG would handle these shufflevectors in a particular way that would lead to the problem described in #143943

When I tried to fix the handling of shuffle reductions in DAGCombiner, I found it harmed codegen on other backends. So I think ExpandReductions is useful for some backends, but not for NVPTX, because we don't have the advanced swizzling instructions other targets have; all we have are mov and prmt (which only works on b32).

So I think it's best to turn off this pass completely.


@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch 2 times, most recently from bc883b6 to cca2a27 on July 29, 2025 02:41
Add support for 3-input fmaxnum/fminnum/fmaximum/fminimum introduced in
PTX 8.8 for sm_100+. Other improvements are:
- Fix lowering of fmaxnum/fminnum so that they are not reassociated.
  According to the IEEE 754 definition of these (maxNum/minNum), they are not
  associative due to how they handle sNaNs. A quick example:
  a = 1.0, b = 1.0, c = sNaN
  maxNum(a, maxNum(b, c)) = maxNum(a, qNaN) = a = 1.0
  maxNum(maxNum(a, b), c) = maxNum(1.0, sNaN) = qNaN
- Use a tree reduction when 3-input operations are supported and the
  reduction has the `reassoc` flag.
- If not on sm_100+/PTX 8.8, fall back to 2-input operations and use the
  default shuffle reduction.
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-vector-reduce branch from cca2a27 to f4f7fe1 on July 29, 2025 02:49