Skip to content

[NVPTX] Fix v2i8 call lowering, use generic ld/st nodes for call params #146930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 28, 2025

Conversation

AlexMaclean
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jul 3, 2025

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-debuginfo
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Patch is 241.60 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146930.diff

37 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (-273)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (-2)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+227-364)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2-8)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+3-130)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+3-3)
  • (modified) llvm/test/CodeGen/NVPTX/byval-const-global.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll (+5-5)
  • (modified) llvm/test/CodeGen/NVPTX/call_bitcast_byval.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/combine-mad.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/compare-int.ll (+501-120)
  • (modified) llvm/test/CodeGen/NVPTX/convert-call-to-indirect.ll (+162-10)
  • (modified) llvm/test/CodeGen/NVPTX/dynamic_stackalloc.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+6-6)
  • (modified) llvm/test/CodeGen/NVPTX/fma.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/forward-ld-param.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/i128-param.ll (+10-10)
  • (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+6-6)
  • (modified) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+93-28)
  • (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+6-6)
  • (modified) llvm/test/CodeGen/NVPTX/idioms.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/indirect_byval.ll (+10-10)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args-gridconstant.ll (+14-19)
  • (modified) llvm/test/CodeGen/NVPTX/lower-args.ll (+4-6)
  • (modified) llvm/test/CodeGen/NVPTX/lower-byval-args.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/misched_func_call.ll (+7-8)
  • (modified) llvm/test/CodeGen/NVPTX/param-add.ll (+4-4)
  • (modified) llvm/test/CodeGen/NVPTX/param-load-store.ll (+124-134)
  • (modified) llvm/test/CodeGen/NVPTX/param-overalign.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/param-vectorize-device.ll (+14-14)
  • (modified) llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/st-param-imm.ll (+146-75)
  • (modified) llvm/test/CodeGen/NVPTX/store-undef.ll (+2-2)
  • (modified) llvm/test/CodeGen/NVPTX/tex-read-cuda.ll (+1-1)
  • (modified) llvm/test/CodeGen/NVPTX/unaligned-param-load-store.ll (+242-352)
  • (modified) llvm/test/CodeGen/NVPTX/vaargs.ll (+8-8)
  • (modified) llvm/test/CodeGen/NVPTX/variadics-backend.ll (+15-17)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 5631342ecc13e..fdd2671a5289e 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -145,18 +145,6 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
     if (tryStoreVector(N))
       return;
     break;
-  case NVPTXISD::LoadParam:
-  case NVPTXISD::LoadParamV2:
-  case NVPTXISD::LoadParamV4:
-    if (tryLoadParam(N))
-      return;
-    break;
-  case NVPTXISD::StoreParam:
-  case NVPTXISD::StoreParamV2:
-  case NVPTXISD::StoreParamV4:
-    if (tryStoreParam(N))
-      return;
-    break;
   case ISD::INTRINSIC_W_CHAIN:
     if (tryIntrinsicChain(N))
       return;
@@ -1429,267 +1417,6 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   return true;
 }
 
-bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
-  SDValue Chain = Node->getOperand(0);
-  SDValue Offset = Node->getOperand(2);
-  SDValue Glue = Node->getOperand(3);
-  SDLoc DL(Node);
-  MemSDNode *Mem = cast<MemSDNode>(Node);
-
-  unsigned VecSize;
-  switch (Node->getOpcode()) {
-  default:
-    return false;
-  case NVPTXISD::LoadParam:
-    VecSize = 1;
-    break;
-  case NVPTXISD::LoadParamV2:
-    VecSize = 2;
-    break;
-  case NVPTXISD::LoadParamV4:
-    VecSize = 4;
-    break;
-  }
-
-  EVT EltVT = Node->getValueType(0);
-  EVT MemVT = Mem->getMemoryVT();
-
-  std::optional<unsigned> Opcode;
-
-  switch (VecSize) {
-  default:
-    return false;
-  case 1:
-    Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
-                             NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
-                             NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64);
-    break;
-  case 2:
-    Opcode =
-        pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
-                        NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
-                        NVPTX::LoadParamMemV2I64);
-    break;
-  case 4:
-    Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
-                             NVPTX::LoadParamMemV4I8, NVPTX::LoadParamMemV4I16,
-                             NVPTX::LoadParamMemV4I32, {/* no v4i64 */});
-    break;
-  }
-  if (!Opcode)
-    return false;
-
-  SDVTList VTs;
-  if (VecSize == 1) {
-    VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
-  } else if (VecSize == 2) {
-    VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
-  } else {
-    EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue };
-    VTs = CurDAG->getVTList(EVTs);
-  }
-
-  unsigned OffsetVal = Offset->getAsZExtVal();
-
-  SmallVector<SDValue, 2> Ops(
-      {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
-
-  ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
-  return true;
-}
-
-// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri)
-#define getOpcV2H(ty, opKind0, opKind1)                                        \
-  NVPTX::StoreParamV2##ty##_##opKind0##opKind1
-
-#define getOpcV2H1(ty, opKind0, isImm1)                                        \
-  (isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r)
-
-#define getOpcodeForVectorStParamV2(ty, isimm)                                 \
-  (isimm[0]) ? getOpcV2H1(ty, i, isimm[1]) : getOpcV2H1(ty, r, isimm[1])
-
-#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3)                      \
-  NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3
-
-#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3)                      \
-  (isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i)                       \
-           : getOpcV4H(ty, opKind0, opKind1, opKind2, r)
-
-#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3)                       \
-  (isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3)                       \
-           : getOpcV4H3(ty, opKind0, opKind1, r, isImm3)
-
-#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3)                        \
-  (isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3)                        \
-           : getOpcV4H2(ty, opKind0, r, isImm2, isImm3)
-
-#define getOpcodeForVectorStParamV4(ty, isimm)                                 \
-  (isimm[0]) ? getOpcV4H1(ty, i, isimm[1], isimm[2], isimm[3])                 \
-             : getOpcV4H1(ty, r, isimm[1], isimm[2], isimm[3])
-
-#define getOpcodeForVectorStParam(n, ty, isimm)                                \
-  (n == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                            \
-           : getOpcodeForVectorStParamV4(ty, isimm)
-
-static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
-                                           unsigned NumElts,
-                                           MVT::SimpleValueType MemTy,
-                                           SelectionDAG *CurDAG, SDLoc DL) {
-  // Determine which inputs are registers and immediates make new operators
-  // with constant values
-  SmallVector<bool, 4> IsImm(NumElts, false);
-  for (unsigned i = 0; i < NumElts; i++) {
-    IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
-    if (IsImm[i]) {
-      SDValue Imm = Ops[i];
-      if (MemTy == MVT::f32 || MemTy == MVT::f64) {
-        const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
-        const ConstantFP *CF = ConstImm->getConstantFPValue();
-        Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
-      } else {
-        const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
-        const ConstantInt *CI = ConstImm->getConstantIntValue();
-        Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
-      }
-      Ops[i] = Imm;
-    }
-  }
-
-  // Get opcode for MemTy, size, and register/immediate operand ordering
-  switch (MemTy) {
-  case MVT::i8:
-    return getOpcodeForVectorStParam(NumElts, I8, IsImm);
-  case MVT::i16:
-    return getOpcodeForVectorStParam(NumElts, I16, IsImm);
-  case MVT::i32:
-    return getOpcodeForVectorStParam(NumElts, I32, IsImm);
-  case MVT::i64:
-    assert(NumElts == 2 && "MVT too large for NumElts > 2");
-    return getOpcodeForVectorStParamV2(I64, IsImm);
-  case MVT::f32:
-    return getOpcodeForVectorStParam(NumElts, F32, IsImm);
-  case MVT::f64:
-    assert(NumElts == 2 && "MVT too large for NumElts > 2");
-    return getOpcodeForVectorStParamV2(F64, IsImm);
-
-  // These cases don't support immediates, just use the all register version
-  // and generate moves.
-  case MVT::i1:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
-                          : NVPTX::StoreParamV4I8_rrrr;
-  case MVT::f16:
-  case MVT::bf16:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
-                          : NVPTX::StoreParamV4I16_rrrr;
-  case MVT::v2f16:
-  case MVT::v2bf16:
-  case MVT::v2i16:
-  case MVT::v4i8:
-    return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
-                          : NVPTX::StoreParamV4I32_rrrr;
-  default:
-    llvm_unreachable("Cannot select st.param for unknown MemTy");
-  }
-}
-
-bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
-  SDLoc DL(N);
-  SDValue Chain = N->getOperand(0);
-  SDValue Param = N->getOperand(1);
-  unsigned ParamVal = Param->getAsZExtVal();
-  SDValue Offset = N->getOperand(2);
-  unsigned OffsetVal = Offset->getAsZExtVal();
-  MemSDNode *Mem = cast<MemSDNode>(N);
-  SDValue Glue = N->getOperand(N->getNumOperands() - 1);
-
-  // How many elements do we have?
-  unsigned NumElts;
-  switch (N->getOpcode()) {
-  default:
-    llvm_unreachable("Unexpected opcode");
-  case NVPTXISD::StoreParam:
-    NumElts = 1;
-    break;
-  case NVPTXISD::StoreParamV2:
-    NumElts = 2;
-    break;
-  case NVPTXISD::StoreParamV4:
-    NumElts = 4;
-    break;
-  }
-
-  // Build vector of operands
-  SmallVector<SDValue, 8> Ops;
-  for (unsigned i = 0; i < NumElts; ++i)
-    Ops.push_back(N->getOperand(i + 3));
-  Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
-              CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), Chain, Glue});
-
-  // Determine target opcode
-  // If we have an i1, use an 8-bit store. The lowering code in
-  // NVPTXISelLowering will have already emitted an upcast.
-  std::optional<unsigned> Opcode;
-  switch (NumElts) {
-  default:
-    llvm_unreachable("Unexpected NumElts");
-  case 1: {
-    MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
-    SDValue Imm = Ops[0];
-    if (MemTy != MVT::f16 && MemTy != MVT::bf16 &&
-        (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
-      // Convert immediate to target constant
-      if (MemTy == MVT::f32 || MemTy == MVT::f64) {
-        const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
-        const ConstantFP *CF = ConstImm->getConstantFPValue();
-        Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
-      } else {
-        const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
-        const ConstantInt *CI = ConstImm->getConstantIntValue();
-        Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
-      }
-      Ops[0] = Imm;
-      // Use immediate version of store param
-      Opcode =
-          pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i, NVPTX::StoreParamI16_i,
-                          NVPTX::StoreParamI32_i, NVPTX::StoreParamI64_i);
-    } else
-      Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
-                               NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
-                               NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r);
-    if (Opcode == NVPTX::StoreParamI8_r) {
-      // Fine tune the opcode depending on the size of the operand.
-      // This helps to avoid creating redundant COPY instructions in
-      // InstrEmitter::AddRegisterOperand().
-      switch (Ops[0].getSimpleValueType().SimpleTy) {
-      default:
-        break;
-      case MVT::i32:
-        Opcode = NVPTX::StoreParamI8TruncI32_r;
-        break;
-      case MVT::i64:
-        Opcode = NVPTX::StoreParamI8TruncI64_r;
-        break;
-      }
-    }
-    break;
-  }
-  case 2:
-  case 4: {
-    MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
-    Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
-    break;
-  }
-  }
-
-  SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
-  SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
-  MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
-  CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
-
-  ReplaceNode(N, Ret);
-  return true;
-}
-
 /// SelectBFE - Look for instruction sequences that can be made more efficient
 /// by using the 'bfe' (bit-field extract) PTX instruction
 bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 0e4dec1adca67..19b569d638b0c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -78,8 +78,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryLDG(MemSDNode *N);
   bool tryStore(SDNode *N);
   bool tryStoreVector(SDNode *N);
-  bool tryLoadParam(SDNode *N);
-  bool tryStoreParam(SDNode *N);
   bool tryFence(SDNode *N);
   void SelectAddrSpaceCast(SDNode *N);
   bool tryBFE(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index bb0aeb493ed48..3cda7df55ad58 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1049,12 +1049,6 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::DeclareArrayParam)
     MAKE_CASE(NVPTXISD::DeclareScalarParam)
     MAKE_CASE(NVPTXISD::CALL)
-    MAKE_CASE(NVPTXISD::LoadParam)
-    MAKE_CASE(NVPTXISD::LoadParamV2)
-    MAKE_CASE(NVPTXISD::LoadParamV4)
-    MAKE_CASE(NVPTXISD::StoreParam)
-    MAKE_CASE(NVPTXISD::StoreParamV2)
-    MAKE_CASE(NVPTXISD::StoreParamV4)
     MAKE_CASE(NVPTXISD::MoveParam)
     MAKE_CASE(NVPTXISD::UNPACK_VECTOR)
     MAKE_CASE(NVPTXISD::BUILD_VECTOR)
@@ -1293,105 +1287,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty,
   return DL.getABITypeAlign(Ty);
 }
 
-static bool adjustElementType(EVT &ElementType) {
-  switch (ElementType.getSimpleVT().SimpleTy) {
-  default:
-    return false;
-  case MVT::f16:
-  case MVT::bf16:
-    ElementType = MVT::i16;
-    return true;
-  case MVT::f32:
-  case MVT::v2f16:
-  case MVT::v2bf16:
-    ElementType = MVT::i32;
-    return true;
-  case MVT::f64:
-    ElementType = MVT::i64;
-    return true;
-  }
-}
-
-// Use byte-store when the param address of the argument value is unaligned.
-// This may happen when the return value is a field of a packed structure.
-//
-// This is called in LowerCall() when passing the param values.
-static SDValue LowerUnalignedStoreParam(SelectionDAG &DAG, SDValue Chain,
-                                        uint64_t Offset, EVT ElementType,
-                                        SDValue StVal, SDValue &InGlue,
-                                        unsigned ArgID, const SDLoc &dl) {
-  // Bit logic only works on integer types
-  if (adjustElementType(ElementType))
-    StVal = DAG.getNode(ISD::BITCAST, dl, ElementType, StVal);
-
-  // Store each byte
-  SDVTList StoreVTs = DAG.getVTList(MVT::Other, MVT::Glue);
-  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
-    // Shift the byte to the last byte position
-    SDValue ShiftVal = DAG.getNode(ISD::SRL, dl, ElementType, StVal,
-                                   DAG.getConstant(i * 8, dl, MVT::i32));
-    SDValue StoreOperands[] = {Chain, DAG.getConstant(ArgID, dl, MVT::i32),
-                               DAG.getConstant(Offset + i, dl, MVT::i32),
-                               ShiftVal, InGlue};
-    // Trunc store only the last byte by using
-    //     st.param.b8
-    // The register type can be larger than b8.
-    Chain = DAG.getMemIntrinsicNode(
-        NVPTXISD::StoreParam, dl, StoreVTs, StoreOperands, MVT::i8,
-        MachinePointerInfo(), Align(1), MachineMemOperand::MOStore);
-    InGlue = Chain.getValue(1);
-  }
-  return Chain;
-}
-
-// Use byte-load when the param adress of the returned value is unaligned.
-// This may happen when the returned value is a field of a packed structure.
-static SDValue
-LowerUnalignedLoadRetParam(SelectionDAG &DAG, SDValue &Chain, uint64_t Offset,
-                           EVT ElementType, SDValue &InGlue,
-                           SmallVectorImpl<SDValue> &TempProxyRegOps,
-                           const SDLoc &dl) {
-  // Bit logic only works on integer types
-  EVT MergedType = ElementType;
-  adjustElementType(MergedType);
-
-  // Load each byte and construct the whole value. Initial value to 0
-  SDValue RetVal = DAG.getConstant(0, dl, MergedType);
-  // LoadParamMemI8 loads into i16 register only
-  SDVTList LoadVTs = DAG.getVTList(MVT::i16, MVT::Other, MVT::Glue);
-  for (unsigned i = 0, n = ElementType.getSizeInBits() / 8; i < n; i++) {
-    SDValue LoadOperands[] = {Chain, DAG.getConstant(1, dl, MVT::i32),
-                              DAG.getConstant(Offset + i, dl, MVT::i32),
-                              InGlue};
-    // This will be selected to LoadParamMemI8
-    SDValue LdVal =
-        DAG.getMemIntrinsicNode(NVPTXISD::LoadParam, dl, LoadVTs, LoadOperands,
-                                MVT::i8, MachinePointerInfo(), Align(1));
-    SDValue TmpLdVal = LdVal.getValue(0);
-    Chain = LdVal.getValue(1);
-    InGlue = LdVal.getValue(2);
-
-    TmpLdVal = DAG.getNode(NVPTXISD::ProxyReg, dl,
-                           TmpLdVal.getSimpleValueType(), TmpLdVal);
-    TempProxyRegOps.push_back(TmpLdVal);
-
-    SDValue CMask = DAG.getConstant(255, dl, MergedType);
-    SDValue CShift = DAG.getConstant(i * 8, dl, MVT::i32);
-    // Need to extend the i16 register to the whole width.
-    TmpLdVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MergedType, TmpLdVal);
-    // Mask off the high bits. Leave only the lower 8bits.
-    // Do this because we are using loadparam.b8.
-    TmpLdVal = DAG.getNode(ISD::AND, dl, MergedType, TmpLdVal, CMask);
-    // Shift and merge
-    TmpLdVal = DAG.getNode(ISD::SHL, dl, MergedType, TmpLdVal, CShift);
-    RetVal = DAG.getNode(ISD::OR, dl, MergedType, RetVal, TmpLdVal);
-  }
-  if (ElementType != MergedType)
-    RetVal = DAG.getNode(ISD::BITCAST, dl, ElementType, RetVal);
-
-  return RetVal;
-}
-
 static bool shouldConvertToIndirectCall(const CallBase *CB,
                                         const GlobalAddressSDNode *Func) {
   if (!Func)
@@ -1458,10 +1353,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
 
   SelectionDAG &DAG = CLI.DAG;
   SDLoc dl = CLI.DL;
-  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
-  SDValue Chain = CLI.Chain;
+  const SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
   SDValue Callee = CLI.Callee;
-  bool &isTailCall = CLI.IsTailCall;
   ArgListTy &Args = CLI.getArgs();
   Type *RetTy = CLI.RetTy;
   const CallBase *CB = CLI.CB;
@@ -1492,9 +1385,34 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
   unsigned VAOffset = 0;                  // current offset in the param array
 
   const unsigned UniqueCallSite = GlobalUniqueCallSite++;
-  SDValue TempChain = Chain;
-  Chain = DAG.getCALLSEQ_START(Chain, UniqueCallSite, 0, dl);
-  SDValue InGlue = Chain.getValue(1);
+  const SDValue CallChain = CLI.Chain;
+  const SDValue StartChain =
+      DAG.getCALLSEQ_START(CallChain, UniqueCallSite, 0, dl);
+  SDValue DeclareGlue = StartChain.getValue(1);
+
+  SmallVector<SDValue, 16> CallPrereqs{StartChain};
+
+  const auto DeclareScalarParam = [&](SDValue Symbol, unsigned Size) {
+    // PTX ABI requires integral types to be at least 32 bits in size. FP16 is
+    // loaded/stored using i16, so it's handled here as well.
+    const unsigned SizeBits = promoteScalarArgumentSize(Size * 8);
+    SDValue Declare =
+        DAG.getNode(NVPTXISD::DeclareScalarParam, dl, {MVT::Other, MVT::Glue},
+                    {StartChain, Symbol, GetI32(SizeBits), DeclareGlue});
+    CallPrereqs.push_back(Declare);
+    DeclareGlue = Declare.getValue(1);
+    return Declare;
+  };
+
+  const auto DeclareArrayParam = [&](SDValue Symbol, Align Align,
+                                     unsigned Size) {
+    SDValue Declare = DAG.getNode(
+        NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
+        {StartChain, Symbol, GetI32(Align.value()), GetI32(Size), DeclareGlue});
+    CallPrereqs.push_back(Declare);
+    DeclareGlue = Declare.getValue(1);
+    return Declare;
+  };
 
   // Args.size() and Outs.size() need not match.
   // Outs.size() will be larger
@@ -1555,43 +1473,23 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
     assert((!IsByVal || TypeSize == ArgOuts[0].Flags.getByValSize()) &&
            "type size mismatch");
 
-    const std::optional<SDValue> ArgDeclare = [&]() -> std::optional<SDValue> {
+    const SDValue ArgDeclare = [&]() {
       if (IsVAArg) {
-        if (ArgI == FirstVAArg) {
-          VADeclareParam = DAG.getNode(
-              NVPTXISD::DeclareArrayParam, dl, {MVT::Other, MVT::Glue},
-              {Chain, ParamSymbol, GetI32(STI.getMaxRequiredAlignment()),
-               GetI32(0), InGlue});
-          return VADeclareParam;
-        }
-        return std::nullopt;
-      }
-      if (IsByVal || shouldPassAsArray(Arg.Ty)) {
-        // declare .param .align <align> .b8 .param<n>[<size>];
-        return DAG.getNode(NVPTXISD::DeclareArrayParam, dl,
-                           {MVT::Other, MVT::Glue},
-                           {Chain, ParamSymbol, GetI32(ArgAlign.value()),
-                            GetI32(TypeSize), InGlue});
+        if (ArgI == FirstVAArg)
+          VADeclareParam = DeclareArrayParam(
+              ParamSymbol, Align(STI.getMaxRequiredAlignment()), 0);
+        return VADeclareParam;
       }
+
+      if (IsByVal || shouldPassAsArray(Arg.Ty))
+        return DeclareArrayParam(ParamSymbol, ArgAlign, TypeSize);
+
       assert(ArgOuts.size() == 1 && "We must pass only one value as non-array");
-      // declare .param .b<size> .param<n>;
-
-      // PTX ABI requ...
[truncated]

Copy link

github-actions bot commented Jul 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch 3 times, most recently from bdc2f75 to a33744a Compare July 3, 2025 18:33
@llvmbot llvmbot added the clang Clang issues not falling into any other category label Jul 3, 2025
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch from a33744a to 43cfe71 Compare July 3, 2025 21:08
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch 2 times, most recently from ab348bc to 8b45f7d Compare July 8, 2025 18:19
Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall.
Some style nits on lambda use in the code, and a couple of questions about observed changes in the generated code.

; CHECK-NEXT: st.param.b8 [param0+2], %r3;
; CHECK-NEXT: shr.u32 %r4, %r3, 8;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The new code is technically worse than the old one, as now this shift has a data dependency on %r3, while previously all three shifts could be done independently. Probably does not matter much (ptxas may optimize it away, or the GPU may not run that may shifts in parallel anyways).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, I think this is a quirk of how generic unaligned stores are lowered and optimized that we'll just have to live with for now. Certainly something to investigate in the future but I don't think it makes sense to address here.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch 2 times, most recently from b43a097 to 8a7ed37 Compare July 9, 2025 19:47
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch from 8a7ed37 to 444dcfa Compare July 23, 2025 17:58
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Jul 23, 2025
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch from 444dcfa to b966989 Compare July 26, 2025 15:36
@AlexMaclean AlexMaclean force-pushed the dev/amaclean/upstream/v2i8-tmp2 branch from b966989 to 15c4720 Compare July 28, 2025 16:38
@AlexMaclean AlexMaclean merged commit 35693da into llvm:main Jul 28, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:NVPTX clang Clang issues not falling into any other category debuginfo llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants