Skip to content

[NVPTX] support packed f32 instructions for sm_100+ #126337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 11, 2025

Conversation

Prince781
Copy link
Contributor

@Prince781 Prince781 commented Feb 8, 2025

This adds support for lowering fadd, fsub, fmul, and fma to sm_100+ packed-f32 instructions1 (e.g. add.rn.f32x2 Int64Reg, Int64Reg). Rather than legalizing v2f32, we handle these four instructions ad hoc, so that codegen remains the same unless these instructions are present. We also introduce some DAGCombiner rules to simplify bitwise packing/unpacking to use mov, and to reduce redundant movs.

In this PR I didn't implement support for alternative rounding modes, as that was lower priority. If there's sufficient demand, I can add that to this PR. Otherwise we can leave that for later.

Footnotes

  1. Introduced in PTX 8.6: https://docs.nvidia.com/cuda/parallel-thread-execution/#changes-in-ptx-isa-version-8-6

Copy link

github-actions bot commented Feb 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented Feb 8, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Princeton Ferro (Prince781)

Changes

This adds support for lowering fadd, fsub, fmul, and fma to sm_100+ packed-f32 instructions1 (e.g. add.rn.f32x2 Int64Reg, Int64Reg). Rather than legalizing v2f32, we handle these four instructions ad hoc, so that codegen remains the same unless these instructions are present. We also introduce some DAGCombiner rules to simplify bitwise packing/unpacking to use mov, and to reduce redundant movs.

In this PR I didn't implement support for alternative rounding modes, as that was lower priority. If there's sufficient demand, I can add that to this PR. Otherwise we can leave that for later.


Patch is 127.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126337.diff

10 Files Affected:

  • (modified) llvm/include/llvm/Target/TargetSelectionDAG.td (+7)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+19)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+275-9)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+4)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+14-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+30)
  • (modified) llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td (+3-1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+3)
  • (added) llvm/test/CodeGen/NVPTX/f32x2-instructions.ll (+2665)
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 42a5fbec95174e..394428594b9870 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -115,6 +115,9 @@ def SDTPtrAddOp : SDTypeProfile<1, 2, [     // ptradd
 def SDTIntBinOp : SDTypeProfile<1, 2, [     // add, and, or, xor, udiv, etc.
   SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>
 ]>;
+def SDTIntTernaryOp : SDTypeProfile<1, 3, [  // fma32x2
+  SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisInt<0>
+]>;
 def SDTIntShiftOp : SDTypeProfile<1, 2, [   // shl, sra, srl
   SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>
 ]>;
@@ -818,6 +821,10 @@ def step_vector : SDNode<"ISD::STEP_VECTOR", SDTypeProfile<1, 1,
 def scalar_to_vector : SDNode<"ISD::SCALAR_TO_VECTOR", SDTypeProfile<1, 1, []>,
                               []>;
 
+def build_pair : SDNode<"ISD::BUILD_PAIR", SDTypeProfile<1, 2,
+                        [SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>, []>;
+
+
 // vector_extract/vector_insert are deprecated. extractelt/insertelt
 // are preferred.
 def vector_extract : SDNode<"ISD::EXTRACT_VECTOR_ELT",
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index ec654e0f3f200f..3a39f6dab0c85f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -190,6 +190,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
       SelectI128toV2I64(N);
       return;
     }
+    if (N->getOperand(1).getValueType() == MVT::i64 &&
+        N->getValueType(0) == MVT::f32 && N->getValueType(1) == MVT::f32) {
+      // {f32,f32} = mov i64
+      SelectI64ToV2F32(N);
+      return;
+    }
     break;
   }
   case ISD::FADD:
@@ -2765,6 +2771,19 @@ void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
   ReplaceNode(N, Mov);
 }
 
+void NVPTXDAGToDAGISel::SelectI64ToV2F32(SDNode *N) {
+  SDValue Ch = N->getOperand(0);
+  SDValue Src = N->getOperand(1);
+  assert(N->getValueType(0) == MVT::f32 && N->getValueType(1) == MVT::f32 &&
+         "expected {f32,f32} = CopyFromReg i64");
+  SDLoc DL(N);
+
+  SDNode *Mov = CurDAG->getMachineNode(NVPTX::I64toV2F32, DL,
+                                       {MVT::f32, MVT::f32, Ch.getValueType()},
+                                       {Src, Ch});
+  ReplaceNode(N, Mov);
+}
+
 /// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
 /// conversion from \p SrcTy to \p DestTy.
 unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 8dc6bc86c68281..703a80f74e90c7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -91,6 +91,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
   void SelectV2I64toI128(SDNode *N);
   void SelectI128toV2I64(SDNode *N);
+  void SelectI64ToV2F32(SDNode *N);
   void SelectCpAsyncBulkG2S(SDNode *N);
   void SelectCpAsyncBulkS2G(SDNode *N);
   void SelectCpAsyncBulkPrefetchL2(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 58ad92a8934a66..1e417f23fdb099 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -866,6 +866,24 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
   // (would be) Library functions.
 
+  if (STI.hasF32x2Instructions()) {
+    // Handle custom lowering for: v2f32 = OP v2f32, v2f32
+    for (const auto &Op : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMA})
+      setOperationAction(Op, MVT::v2f32, Custom);
+    // Handle custom lowering for: f32 = extract_vector_elt v2f32
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
+    // Combine:
+    // i64 = or (i64 = zero_extend X, i64 = shl (i64 = any_extend Y, 32))
+    // -> i64 = build_pair (X, Y)
+    setTargetDAGCombine(ISD::OR);
+    // i32 = truncate (i64 = srl (i64 = build_pair (X, Y), 32))
+    // -> i32 Y
+    setTargetDAGCombine(ISD::TRUNCATE);
+    // i64 = build_pair ({i32, i32} = CopyFromReg (CopyToReg (i64 X)))
+    // -> i64 X
+    setTargetDAGCombine(ISD::BUILD_PAIR);
+  }
+
   // These map to conversion instructions for scalar FP types.
   for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
                          ISD::FROUNDEVEN, ISD::FTRUNC}) {
@@ -1066,6 +1084,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::STACKSAVE)
     MAKE_CASE(NVPTXISD::SETP_F16X2)
     MAKE_CASE(NVPTXISD::SETP_BF16X2)
+    MAKE_CASE(NVPTXISD::FADD_F32X2)
+    MAKE_CASE(NVPTXISD::FSUB_F32X2)
+    MAKE_CASE(NVPTXISD::FMUL_F32X2)
+    MAKE_CASE(NVPTXISD::FMA_F32X2)
     MAKE_CASE(NVPTXISD::Dummy)
     MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
     MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
@@ -2207,6 +2229,30 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
     return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
   }
 
+  if (VectorVT == MVT::v2f32) {
+    auto GetOperand = [&DAG, &DL](SDValue Op, SDValue Index) {
+      if (const auto *ConstIdx = dyn_cast<ConstantSDNode>(Index))
+        return Op.getOperand(ConstIdx->getZExtValue());
+      SDValue E0 = Op.getOperand(0);
+      SDValue E1 = Op.getOperand(1);
+      return DAG.getSelectCC(DL, Index, DAG.getIntPtrConstant(0, DL), E0, E1,
+                             ISD::CondCode::SETEQ);
+    };
+    if (SDValue Pair = Vector.getOperand(0);
+        Vector.getOpcode() == ISD::BITCAST &&
+        Pair.getOpcode() == ISD::BUILD_PAIR) {
+      // peek through v2f32 = bitcast (i64 = build_pair (i32 A, i32 B))
+      // where A:i32, B:i32 = CopyFromReg (i64 = F32X2 Operation ...)
+      return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(),
+                         GetOperand(Pair, Index));
+    }
+    if (Vector.getOpcode() == ISD::BUILD_VECTOR)
+      return GetOperand(Vector, Index);
+
+    // Otherwise, let SelectionDAG expand the operand.
+    return SDValue();
+  }
+
   // Constant index will be matched by tablegen.
   if (isa<ConstantSDNode>(Index.getNode()))
     return Op;
@@ -4573,26 +4619,109 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
+// If {Lo, Hi} = <packed f32x2 val>, returns that value
+static SDValue peekThroughF32x2Copy(const SDValue &Lo, const SDValue &Hi) {
+  if (Lo.getValueType() != MVT::f32 || Lo.getOpcode() != ISD::CopyFromReg ||
+      Lo.getNode() != Hi.getNode() || Lo == Hi)
+    return SDValue();
+
+  SDNode *CopyF = Lo.getNode();
+  SDNode *CopyT = CopyF->getOperand(0).getNode();
+  if (CopyT->getOpcode() != ISD::CopyToReg)
+    return SDValue();
+
+  // check the two registers are the same
+  if (cast<RegisterSDNode>(CopyF->getOperand(1))->getReg() !=
+      cast<RegisterSDNode>(CopyT->getOperand(1))->getReg())
+    return SDValue();
+
+  SDValue OrigV = CopyT->getOperand(2);
+  if (OrigV.getValueType() != MVT::i64)
+    return SDValue();
+  return OrigV;
+}
+
+static SDValue
+PerformPackedF32StoreCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                             CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  // rewrite stores of packed f32 values
+  auto *MemN = cast<MemSDNode>(N);
+  if (MemN->getMemoryVT() == MVT::f32) {
+    std::optional<NVPTXISD::NodeType> NewOpcode;
+    switch (MemN->getOpcode()) {
+    case NVPTXISD::StoreRetvalV2:
+      NewOpcode = NVPTXISD::StoreRetval;
+      break;
+    case NVPTXISD::StoreRetvalV4:
+      NewOpcode = NVPTXISD::StoreRetvalV2;
+      break;
+    case NVPTXISD::StoreParamV2:
+      NewOpcode = NVPTXISD::StoreParam;
+      break;
+    case NVPTXISD::StoreParamV4:
+      NewOpcode = NVPTXISD::StoreParamV2;
+      break;
+    }
+
+    if (NewOpcode) {
+      SmallVector<SDValue> NewOps = {N->getOperand(0), N->getOperand(1)};
+      unsigned NumPacked = 0;
+
+      // gather all packed operands
+      for (unsigned I = 2, E = MemN->getNumOperands(); I < E; I += 2) {
+        if (SDValue Packed = peekThroughF32x2Copy(MemN->getOperand(I),
+                                                  MemN->getOperand(I + 1))) {
+          NewOps.push_back(Packed);
+          ++NumPacked;
+        } else {
+          NumPacked = 0;
+          break;
+        }
+      }
+
+      if (NumPacked) {
+        return DCI.DAG.getMemIntrinsicNode(
+            *NewOpcode, SDLoc(N), N->getVTList(), NewOps, MVT::i64,
+            MemN->getPointerInfo(), MemN->getAlign(),
+            MachineMemOperand::MOStore);
+      }
+    }
+  }
+  return SDValue();
+}
+
 static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
-                                         std::size_t Back) {
+                                         std::size_t Back,
+                                         TargetLowering::DAGCombinerInfo &DCI,
+                                         CodeGenOptLevel OptLevel) {
   if (all_of(N->ops().drop_front(Front).drop_back(Back),
              [](const SDUse &U) { return U.get()->isUndef(); }))
     // Operand 0 is the previous value in the chain. Cannot return EntryToken
     // as the previous value will become unused and eliminated later.
     return N->getOperand(0);
 
+  if (SDValue V = PerformPackedF32StoreCombine(N, DCI, OptLevel))
+    return V;
+
   return SDValue();
 }
 
-static SDValue PerformStoreParamCombine(SDNode *N) {
+static SDValue PerformStoreParamCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        CodeGenOptLevel OptLevel) {
   // Operands from the 3rd to the 2nd last one are the values to be stored.
   //   {Chain, ArgID, Offset, Val, Glue}
-  return PerformStoreCombineHelper(N, 3, 1);
+  return PerformStoreCombineHelper(N, 3, 1, DCI, OptLevel);
 }
 
-static SDValue PerformStoreRetvalCombine(SDNode *N) {
+static SDValue PerformStoreRetvalCombine(SDNode *N,
+                                         TargetLowering::DAGCombinerInfo &DCI,
+                                         CodeGenOptLevel OptLevel) {
   // Operands from the 2nd to the last one are the values to be stored
-  return PerformStoreCombineHelper(N, 2, 0);
+  return PerformStoreCombineHelper(N, 2, 0, DCI, OptLevel);
 }
 
 /// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5055,10 +5184,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
       IsPTXVectorType(VectorVT.getSimpleVT()))
     return SDValue(); // Native vector loads already combine nicely w/
                       // extract_vector_elt.
-  // Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
+  // Don't mess with singletons or v2*16, v4i8, v8i8, or v2f32 types, we already
   // handle them OK.
   if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
-      VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
+      VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8 || VectorVT == MVT::v2f32)
     return SDValue();
 
   // Don't mess with undef values as sra may be simplified to 0, not undef.
@@ -5188,6 +5317,78 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
 }
 
+static SDValue PerformORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                                CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+
+  // i64 = or (i64 = zero_extend A, i64 = shl (i64 = any_extend B, 32))
+  // -> i64 = build_pair (A, B)
+  if (N->getValueType(0) == MVT::i64 && Op0.getOpcode() == ISD::ZERO_EXTEND &&
+      Op1.getOpcode() == ISD::SHL) {
+    SDValue SHLOp0 = Op1.getOperand(0);
+    SDValue SHLOp1 = Op1.getOperand(1);
+    if (const auto *Const = dyn_cast<ConstantSDNode>(SHLOp1);
+        Const && Const->getZExtValue() == 32 &&
+        SHLOp0.getOpcode() == ISD::ANY_EXTEND) {
+      SDLoc DL(N);
+      return DCI.DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
+                             {Op0.getOperand(0), SHLOp0.getOperand(0)});
+    }
+  }
+  return SDValue();
+}
+
+static SDValue PerformTRUNCATECombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI,
+                                      CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  SDValue Op = N->getOperand(0);
+  if (Op.getOpcode() == ISD::SRL) {
+    SDValue SrlOp = Op.getOperand(0);
+    SDValue SrlSh = Op.getOperand(1);
+    // i32 = truncate (i64 = srl (i64 build_pair (A, B), 32))
+    // -> i32 A
+    if (const auto *Const = dyn_cast<ConstantSDNode>(SrlSh);
+        Const && Const->getZExtValue() == 32) {
+      if (SrlOp.getOpcode() == ISD::BUILD_PAIR)
+        return SrlOp.getOperand(1);
+    }
+  }
+
+  return SDValue();
+}
+
+static SDValue PerformBUILD_PAIRCombine(SDNode *N,
+                                        TargetLowering::DAGCombinerInfo &DCI,
+                                        CodeGenOptLevel OptLevel) {
+  if (OptLevel == CodeGenOptLevel::None)
+    return SDValue();
+
+  EVT ToVT = N->getValueType(0);
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  // i64 = build_pair ({i32, i32} = CopyFromReg (CopyToReg (i64 X)))
+  // -> i64 X
+  if (ToVT == MVT::i64 && Op0.getOpcode() == ISD::CopyFromReg &&
+      Op1.getNode() == Op0.getNode() && Op0 != Op1) {
+    SDValue CFRChain = Op0.getOperand(0);
+    Register Reg = cast<RegisterSDNode>(Op0.getOperand(1))->getReg();
+    if (CFRChain.getOpcode() == ISD::CopyToReg &&
+        cast<RegisterSDNode>(CFRChain.getOperand(1))->getReg() == Reg) {
+      SDValue Value = CFRChain.getOperand(2);
+      return Value;
+    }
+  }
+
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5211,17 +5412,23 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     case NVPTXISD::StoreRetval:
     case NVPTXISD::StoreRetvalV2:
     case NVPTXISD::StoreRetvalV4:
-      return PerformStoreRetvalCombine(N);
+      return PerformStoreRetvalCombine(N, DCI, OptLevel);
     case NVPTXISD::StoreParam:
     case NVPTXISD::StoreParamV2:
     case NVPTXISD::StoreParamV4:
-      return PerformStoreParamCombine(N);
+      return PerformStoreParamCombine(N, DCI, OptLevel);
     case ISD::EXTRACT_VECTOR_ELT:
       return PerformEXTRACTCombine(N, DCI);
     case ISD::VSELECT:
       return PerformVSELECTCombine(N, DCI);
     case ISD::BUILD_VECTOR:
       return PerformBUILD_VECTORCombine(N, DCI);
+    case ISD::OR:
+      return PerformORCombine(N, DCI, OptLevel);
+    case ISD::TRUNCATE:
+      return PerformTRUNCATECombine(N, DCI, OptLevel);
+    case ISD::BUILD_PAIR:
+      return PerformBUILD_PAIRCombine(N, DCI, OptLevel);
   }
   return SDValue();
 }
@@ -5478,6 +5685,59 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
   Results.push_back(NewValue.getValue(3));
 }
 
+static void ReplaceF32x2Op(SDNode *N, SelectionDAG &DAG,
+                           SmallVectorImpl<SDValue> &Results) {
+  SDLoc DL(N);
+  EVT OldResultTy = N->getValueType(0); // <2 x float>
+  assert(OldResultTy == MVT::v2f32 && "Unexpected result type for F32x2 op!");
+
+  SmallVector<SDValue> NewOps;
+
+  // whether we use FTZ (TODO)
+
+  // replace with NVPTX F32x2 op:
+  unsigned Opcode;
+  switch (N->getOpcode()) {
+  case ISD::FADD:
+    Opcode = NVPTXISD::FADD_F32X2;
+    break;
+  case ISD::FSUB:
+    Opcode = NVPTXISD::FSUB_F32X2;
+    break;
+  case ISD::FMUL:
+    Opcode = NVPTXISD::FMUL_F32X2;
+    break;
+  case ISD::FMA:
+    Opcode = NVPTXISD::FMA_F32X2;
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
+
+  // bitcast operands: <2 x float> -> i64
+  for (const SDValue &Op : N->ops())
+    NewOps.push_back(DAG.getNode(ISD::BITCAST, DL, MVT::i64, Op));
+
+  SDValue Chain = DAG.getEntryNode();
+
+  // break packed result into two f32 registers for later instructions that may
+  // access element #0 or #1
+  SDValue NewValue = DAG.getNode(Opcode, DL, MVT::i64, NewOps);
+  MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+  Register DestReg = RegInfo.createVirtualRegister(
+      DAG.getTargetLoweringInfo().getRegClassFor(MVT::i64));
+  SDValue RegCopy = DAG.getCopyToReg(Chain, DL, DestReg, NewValue);
+  SDValue Explode = DAG.getNode(ISD::CopyFromReg, DL,
+                                {MVT::f32, MVT::f32, Chain.getValueType()},
+                                {RegCopy, DAG.getRegister(DestReg, MVT::i64)});
+  // cast i64 result of new op back to <2 x float>
+  Results.push_back(DAG.getBitcast(
+      OldResultTy,
+      DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
+                  {DAG.getBitcast(MVT::i32, Explode.getValue(0)),
+                   DAG.getBitcast(MVT::i32, Explode.getValue(1))})));
+}
+
 void NVPTXTargetLowering::ReplaceNodeResults(
     SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
   switch (N->getOpcode()) {
@@ -5495,6 +5755,12 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   case ISD::CopyFromReg:
     ReplaceCopyFromReg_128(N, DAG, Results);
     return;
+  case ISD::FADD:
+  case ISD::FSUB:
+  case ISD::FMUL:
+  case ISD::FMA:
+    ReplaceF32x2Op(N, DAG, Results);
+    return;
   }
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 5adf69d621552f..8fd4ded42a238a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -55,6 +55,10 @@ enum NodeType : unsigned {
   FSHR_CLAMP,
   MUL_WIDE_SIGNED,
   MUL_WIDE_UNSIGNED,
+  FADD_F32X2,
+  FMUL_F32X2,
+  FSUB_F32X2,
+  FMA_F32X2,
   SETP_F16X2,
   SETP_BF16X2,
   BFE,
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 7d9697e40e6aba..b0eb9bbbb2456a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -165,6 +165,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
 def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
 def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
 def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
+def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
 
 def True : Predicate<"true">;
 def False : Predicate<"false">;
@@ -2638,13 +2639,13 @@ class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
   NVPTXInst<(outs), (ins regclass:$a), "$a",
             [(LastCallArg (i32 0), vt:$a)]>;
 
-def CallArgI64     : CallArgInst<Int64Regs>;
+def CallArgI64     : CallArgInstVT<Int64Regs, i64>;
 def CallArgI32     : CallArgInstVT<Int32Regs, i32>;
 def CallArgI16     : CallArgInstVT<Int16Regs, i16>;
 def CallArgF64     : CallArgInst<Float64Regs>;
 def CallArgF32     : CallArgInst<Float32Regs>;
 
-def LastCallArgI64 : LastCallArgInst<Int64Regs>;
+def LastCallArgI64 : LastCallArgInstVT<Int64Regs, i64>;
 def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
 def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
 def LastCallArgF64 : LastCallArgInst<Float64Regs>;
@@ -3371,6 +3372,9 @@ let hasSideEffects = false in {
   def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
                              (ins Float32Regs:$s1, Float32Regs:$s2),
                              "mov.b64 \t$d, {{$s1, $s2}};", []>;
+  def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
+                             (ins Float32Regs:$s1, Float32Regs:$s2),
+                             "mov.b64 \t$d, {{$s1, $s2}};", []>;
 
   // unpack a larger int register to a set of smaller int registers
   def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -3383,6 +3387,9 @@ let hasSideEffects = false in {
   def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
                              (ins Int64Regs:$s),
                              "mov.b64 \t{{$d1, $d2}}, $...
[truncated]

Footnotes

  1. Introduced in PTX 8.6: https://docs.nvidia.com/cuda/parallel-thread-execution/#changes-in-ptx-isa-version-8-6

Copy link
Member

@AlexMaclean AlexMaclean left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than legalizing v2f32, we handle these four instructions ad hoc, so that codegen remains the same unless these instructions are present.

Supporting v2f32 (similar to how we support v2f16 for example) would be a cleaner and more extensible way to implement this change.

@Prince781
Copy link
Contributor Author

Prince781 commented Feb 8, 2025

@AlexMaclean what led me to this implementation (and I did try it the other way) is that v2f16 and v2bf16 are supported by many kinds of instructions, so it makes more sense to legalize these types than v2f32. My concern is whether this feature should change code that uses f32 vectors but avoids these operations. Legalizing this type requires me to change some things in how we lower instructions, like loads and stores (for example, if we don't want ld.v2.f32 to always become ld.b64), override a few things in TLI, etc.

See the test cases for more examples. If this is not a concern, then I can go with the other implementation.

@Prince781
Copy link
Contributor Author

Okay, so I'm in the process of reworking this PR to legalize v2f32. If there's interest in the current approach, feel free to review. Otherwise, please wait until I update this.

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-f32x2 branch 5 times, most recently from f46de86 to a00289e Compare March 13, 2025 10:43
@Prince781
Copy link
Contributor Author

Prince781 commented Mar 13, 2025

Hi,

I've updated the patch to legalize v2f32 as i64 and plumbed through the various areas. I've also introduced DAGCombine rules to ensure we still codegen ld.vN / st.vN when it makes sense to do so (when all accesses are of elements of v2f32). Along with some other rules for bitwise operations, this results in PTX codegen looking roughly the same if you're not using .f32x2 instructions.

The patch probably needs more coverage and I'd appreciate suggestions in that direction.

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-f32x2 branch 2 times, most recently from 774cfd0 to 520da79 Compare March 15, 2025 02:00
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-f32x2 branch from 520da79 to 07e7869 Compare April 1, 2025 20:35
@Prince781 Prince781 force-pushed the dev/pferro/nvptx-f32x2 branch 4 times, most recently from ad79e84 to cd1ca40 Compare April 10, 2025 22:36
@Prince781
Copy link
Contributor Author

Pinging reviewers: @AlexMaclean, @Artem-B, @arsenm, @justinfargnoli, @durga4github

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-f32x2 branch 2 times, most recently from c6ed6a8 to 2b31795 Compare April 11, 2025 02:03
Copy link
Member

@AlexMaclean AlexMaclean left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments but will need to review more thoroughly. This change is quite large and complex, would it be possible to break it up in any way? Such as adding some of the DAG combine rules in a separate change?

Comment on lines 1470 to 1476
switch (StoreVT.getSimpleVT().SimpleTy) {
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
case MVT::v4i8:
case MVT::v2f32:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing something like this in lots of place (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32). Lets add a helper function in NVPTXUtilities.h and use that throughout.

@Prince781 Prince781 force-pushed the dev/pferro/nvptx-f32x2 branch from 2b31795 to 8cb359c Compare May 6, 2025 07:58
bangtianliu added a commit to iree-org/llvm-project that referenced this pull request Jul 17, 2025
Artem-B added a commit to Artem-B/llvm-project that referenced this pull request Jul 17, 2025
bangtianliu added a commit to iree-org/llvm-project that referenced this pull request Jul 18, 2025
@npanchen
Copy link
Contributor

@Prince781 due to this change we have lots of runtime errors, i.e. accuracy mismatch which #149393 does not solve.
If @bangtianliu observed this too, it's worth to revert that change.

@npanchen
Copy link
Contributor

@Prince781 due to this change we have lots of runtime errors, i.e. accuracy mismatch which #149393 does not solve. If @bangtianliu observed this too, it's worth to revert that change.

The main difference I see is right after isel:

; good ir

%61:b32 = CVT_bf16x2_f32 killed %60:b32, killed %59:b32, 5
%62:b32 = CVT_bf16x2_f32 killed %58:b32, killed %57:b32, 5
...
ST_i32 %62:b32,...
ST_i32 %61:b32,...
...
; bad ir

%61:b64 = V2I32toI64 killed %59:b32, killed %60:b32
%62:b64 = V2I32toI64 killed %57:b32, killed %58:b32
...
ST_i64 %62:b64,...
ST_i64 %61:b64,...

i.e. after that change there were no f32->bf16 conversions generated

@npanchen
Copy link
Contributor

smallest reproducer:

define ptx_kernel void @test(float %0) {
  %2 = insertelement <2 x float> zeroinitializer, float %0, i64 0
  %3 = fptrunc <2 x float> %2 to <2 x bfloat>
  store <2 x bfloat> %3, ptr addrspace(1) null, align 2
  ret void
}

without the change isel generates:

bb.0 (%ir-block.1):
  %0:b32 = LD_i32 0, 0, 101, 3, 32, &test_param_0, 0 :: (dereferenceable invariant load (s32), addrspace 101)
  %1:b16 = CVT_bf16_f32 killed %0:b32, 5
  %3:b16 = IMPLICIT_DEF
  %2:b32 = V2I16toI32 killed %1:b16, killed %3:b16
  %4:b64 = IMOV64i 0
  ST_i32 killed %2:b32, 0, 0, 1, 16, killed %4:b64, 0 :: (store (s16) into `ptr addrspace(1) null`, addrspace 1)
  %5:b64 = IMOV64i 2
  ST_i32 0, 0, 0, 1, 16, killed %5:b64, 0 :: (store (s16) into `ptr addrspace(1) null` + 2, addrspace 1)
  Return

with the change:

bb.0 (%ir-block.1):
  %0:b32 = LD_i32 0, 0, 101, 3, 32, &test_param_0, 0 :: (dereferenceable invariant load (s32), addrspace 101)
  %1:b32 = FMOV32i float 0.000000e+00
  %2:b64 = V2I32toI64 killed %0:b32, killed %1:b32
  %3:b64 = IMOV64i 0
  ST_i64 %2:b64, 0, 0, 1, 16, killed %3:b64, 0 :: (store (s16) into `ptr addrspace(1) null`, addrspace 1)
  %4:b64 = SRLi64ri %2:b64, 16
  %5:b64 = IMOV64i 2
  ST_i64 killed %4:b64, 0, 0, 1, 16, killed %5:b64, 0 :: (store (s16) into `ptr addrspace(1) null` + 2, addrspace 1)
  %6:b64 = IMOV64i 6
  ST_i64 0, 0, 0, 1, 16, killed %6:b64, 0 :: (store (s16) into `ptr addrspace(1) null` + 6, addrspace 1)
  %7:b64 = IMOV64i 4
  ST_i64 0, 0, 0, 1, 16, killed %7:b64, 0 :: (store (s16) into `ptr addrspace(1) null` + 4, addrspace 1)
  Return

cc @Artem-B

@Artem-B
Copy link
Member

Artem-B commented Jul 18, 2025

Can you elaborate on what is the remaining issue with #149393 applied?

Looking at the differences from the older LLC vs the HEAD, it appears that the main difference is in how we're handling NaNs: https://godbolt.org/z/MfrYcTK5d

Yeah, that might be another issue.

@AlexMaclean
Copy link
Member

Looks like we need to add the following:

  setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand);
  setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand);

Otherwise the fp_round will get folded into the store which we can't handle.

@npanchen
Copy link
Contributor

Looks like we need to add the following:

  setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand);
  setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand);

Otherwise the fp_round will get folded into the store which we can't handle.

Yes, I finally got a chance to debug it and realized there's incorrect handling of truncating stores for a long period of time. Will you(@AlexMaclean) or @Artem-B or @Prince781 create a fix ?

@npanchen
Copy link
Contributor

Looks like we need to add the following:

  setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand);
  setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand);

Otherwise the fp_round will get folded into the store which we can't handle.

With both #149393 and #149571 my local testing is clean!

bjacob pushed a commit to iree-org/llvm-project that referenced this pull request Jul 21, 2025
@npanchen
Copy link
Contributor

there's one more runtime issue I later hit due to that change. Here are 2 ptxes (bad, good) with and without that change. I'm by no means expert in PTX and to me packs, unpacks are correct. @Artem-B @Prince781 @AlexMaclean can you please take a look?

@AlexMaclean
Copy link
Member

@npanchen Would it be possible to attach the IR that the PTX was generated from?

@npanchen
Copy link
Contributor

npanchen commented Jul 21, 2025

@npanchen Would it be possible to attach the IR that the PTX was generated from?

Absolutely: https://gist.github.com/npanchen/f3c5b5b657cb2a04a17c0b5467090383
Since I don't understand what's wrong, I cannot help to reduce the test case. Definitely will be glad to know what has happened and why "bad" PTX is bad

upd: that's pre-llc LLVM IR. Obviously IR is identical before and after that change

@npanchen
Copy link
Contributor

@AlexMaclean any updates ?

@npanchen
Copy link
Contributor

@AlexMaclean @Prince781 ping

@Prince781
Copy link
Contributor Author

@npanchen I compared the "good" and "bad" PTX and I notice there are now mov.b64s inside your async wgmma pipeline (between wgmma.mma_async calls). Register accesses inside are disallowed1 and ptxas will force wgmma to be synchronous.

Footnotes

  1. See PTX ISA for wgmma.mma_async: "wgmma.fence instruction must be used to fence the register accesses of wgmma.mma_async instruction from their prior accesses. Otherwise, the behavior is undefined."

@Prince781
Copy link
Contributor Author

Prince781 commented Jul 22, 2025

The problem is inadequate modeling of the semantics of wgmma.mma_async in inline ASM. You have extractelements between your wgmma.fence and mma, which breaks WGMMA semantics. I think having a wgmma.mma_async intrinsic implementation in LLVM IR would fix this issue. Was the original source code for this kernel in CUDA C++?

@Prince781
Copy link
Contributor Author

@AlexMaclean I don't think we need to revert as this is exposing code that was technically already broken. We need to improve the semantics of wgmma.mma_async with an intrinsic. For now there are workarounds at the source level that other library vendors use.

@Prince781
Copy link
Contributor Author

@npanchen if the original source was CUDA C++, then you can use the same trick CUTLASS uses where you pass the operand through a asm volatile to get the desired anti-dependency with your wgmma.mma_async inline ASM call.

@Prince781
Copy link
Contributor Author

@npanchen it looks like there's a similar thing at the LLVM IR level (@llvm.arithmetic.fence) you could also look into, although I haven't tried it myself.

@npanchen
Copy link
Contributor

npanchen commented Jul 22, 2025

@Prince781 thanks for the explanation!
That code actually comes from Mojo. If I got the restriction right, there should be no writes to the register after the fence and prior to the use by wgmma.mma_async instruction. If so it does seem to be missed peephole optimization for high-level languages so that it will move such instructions before the fence for trivial cases like that.
For example, if I just consider CUDA C++ and this example https://github.com/bertmaher/simplegemm/blob/main/pingpong.cu#L458-L471, the way it's written is:

  1. fence is generated before the loop
  2. inline assembly uses d[x][y]

there should be a load(gep) after the fence, i.e. similar to extractelement in our case. Unless that code is invalid too, there could be a transformation that moves these instructions out or it's simply expected that mem2reg will happen, i.e. it won't work with O0 correctly too.

@npanchen
Copy link
Contributor

@Prince781 btw, another concern I got from internal developers is that this PR potentially will impact perf badly for sm_90[a]. At least SASS does have more instructions (with more fences that will be even more). Is there a plan to limit new packing mechanism to only use cases mentioned in description and the sm_100+ target ?

@Prince781
Copy link
Contributor Author

out or it's simply expected that mem2reg will happen, i.e. it won't work with O0 correctly too.

Yes, this is actually necessary to get rid of the intermediate registers. However, it will still work correctly in O0, it will just be serialized by the compiler (ptxas) as a fail-safe, so you'll lose a lot of perf.

btw, another concern I got from internal developers is that this PR potentially will impact perf badly for sm_90[a]. At least SASS does have more instructions (with more fences that will be even more). Is there a plan to limit new packing mechanism to only use cases mentioned in description and the sm_100+ target ?

Perhaps we can improve things pre-sm_100 with more DAGCombiner rules. I can take a closer look at your example later. If you have any more please post them here. If you're talking about using arithmetic fences, those shouldn't compile to anything in SASS.

@npanchen
Copy link
Contributor

Yes, this is actually necessary to get rid of the intermediate registers. However, it will still work correctly in O0, it will just be serialized by the compiler (ptxas) as a fail-safe, so you'll lose a lot of perf.

If that's safe, I don't understand what's a difference between load(gep) and extractelement. If latter has UB, why former does not ? Can you elaborate or maybe point me to a explanation document ?

Perhaps we can improve things pre-sm_100 with more DAGCombiner rules. I can take a closer look at your example later. If you have any more please post them here. If you're talking about using arithmetic fences, those shouldn't compile to anything in SASS.

As soon as I will complete pulldown into our internal repo and if we will get regressions, I will definitely let you know.
I was only referring to wgmma.fence. We don't use arithmetic fence and its behavior is clear to me, because it has well defined scope.

@Prince781
Copy link
Contributor Author

If that's safe, I don't understand what's a difference between load(gep) and extractelement. If latter has UB, why former does not ? Can you elaborate or maybe point me to a explanation document ?

They both are undefined according to the PTX spec. But what the compiler ptxas will actually do is serialize the WGMMAs. You may even get a warning.

@Artem-B
Copy link
Member

Artem-B commented Jul 29, 2025

@Prince781 it appears that the PTX generated after this patch (and subsequent fixes) triggers a miscompilation in ptxas.

So far it shows up as an invalid result of a fairly large computation on one test that resists being reduced, but the issue goes away if the PTX is compiled with lower ptxas optimization settings. The issue shows up with both cuda 12.8.1 and 12.9.1.

Attached is the IR and PTX files for the working/failing kernel, but I have no test usable for reducing the kernel further.
ptxas-miscompile-reproducer.zip

@AlexMaclean, @gonzalobg : ^^^ FYI.

@Prince781
Copy link
Contributor Author

@Artem-B I'm taking a look. So far I see no change in the number of instructions outside of ld, st, and mov. When adding up ld and st for b32 and b64, the total data is unchanged.

% icdiff <(histptx good/module_0013.jit_svd.ptx) <(histptx bad5/module_0013.jit_svd.ptx)                                                                                   25-07-29 - 17:46:59
/proc/self/fd/11                                                                               /proc/self/fd/12                                                                              
      4 cvt.u16.u32                                                                                  4 cvt.u16.u32                                                                           
      5 cvt.u32.u16                                                                                  5 cvt.u32.u16                                                                           
      3 cvt.u64.u16                                                                                  3 cvt.u64.u16                                                                           
     61 cvt.u64.u32                                                                                 61 cvt.u64.u32                                                                           
     24 div.full.f32                                                                                24 div.full.f32                                                                          
    288 ld.b32                                                                                     144 ld.b64                                                                                
     25 ld.global.nc.b32                                                                            25 ld.global.nc.b32                                                                      
     46 ld.global.nc.v2.b32                                                                         46 ld.global.nc.v2.b32                                                                   
     32 ld.param.b64                                                                                32 ld.param.b64                                                                          
     96 ld.shared.b32                                                                               96 ld.shared.b32                                                                         
      2 mad.lo.s16                                                                                   2 mad.lo.s16                                                                            
     45 mad.lo.s32                                                                                  45 mad.lo.s32                                                                            
    122 mov.b32                                                                                     18 mov.b32                                                                               
     32 mov.b64                                                                                    184 mov.b64                                                                               
      6 mov.pred                                                                                     6 mov.pred                                                                              
     22 mov.u32                                                                                     22 mov.u32                                                                               
      2 mul.hi.u16                                                                                   2 mul.hi.u16                                                                            
      2 mul.hi.u32                                                                                   2 mul.hi.u32                                                                            
      9 mul.lo.s16                                                                                   9 mul.lo.s16                                                                            
---                                                                                            ---                                                                                           
      8 setp.num.f32                                                                                 8 setp.num.f32                                                                          
     16 shl.b32                                                                                     16 shl.b32                                                                               
     43 shl.b64                                                                                     43 shl.b64                                                                               
     14 shr.u16                                                                                     14 shr.u16                                                                               
      8 shr.u32                                                                                      8 shr.u32                                                                               
    306 st.b32                                                                                      18 st.b32                                                                                
                                                                                                   144 st.b64                                                                                
      3 st.global.b32                                                                                3 st.global.b32                                                                         
     51 st.global.v2.b32                                                                            51 st.global.v2.b32                                                                      
      1 st.relaxed.sys.global.b32                                                                    1 st.relaxed.sys.global.b32                                                             
     96 st.shared.b32                                                                               96 st.shared.b32                                                                         
     67 sub.rn.f32                                                                                  67 sub.rn.f32

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants