diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..5f98b1a27617d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N,
 }
 
 static SDValue
-PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+PerformBUILD_VECTOROfV2i16Combine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
   auto VT = N->getValueType(0);
   if (!DCI.isAfterLegalizeDAG() ||
       // only process v2*16 types
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getBitcast(VT, PRMT);
 }
 
+static SDValue
+PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI) {
+  // Match: BUILD_VECTOR of v4i8, where the first two elements are from a
+  // NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
+  // zero constants. Replace with: zext the loaded i16 to i32, and return as a
+  // bitcast to v4i8.
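+  // In DAG terms:
+  //   (v4i8 (build_vector LoadV2:0, LoadV2:1, 0, 0))
+  //     -> (v4i8 (bitcast (i32 (zext (i16 (load ptr))))))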
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v4i8)
+    return SDValue();
+  // Check operands: [0]=lo, [1]=hi
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  // Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
+  // NVPTXISD::LDUV2
+  if (Op0.getNode() != Op1.getNode())
+    return SDValue();
+  if (!(Op0.getOpcode() == NVPTXISD::LoadV2 ||
+        Op0.getOpcode() == NVPTXISD::LDUV2))
+    return SDValue();
+  if (Op0.getValueType() != MVT::i16)
+    return SDValue();
+  if (!(Op0.hasOneUse() && Op1.hasOneUse()))
+    return SDValue();
+
+  // Check operands: [2]= 0 or undef, [3]= 0 or undef
+  SDValue Op2 = N->getOperand(2);
+  SDValue Op3 = N->getOperand(3);
+  if (Op2 != Op3)
+    return SDValue();
+  if (!Op2.isUndef()) {
+    auto *C2 = dyn_cast<ConstantSDNode>(Op2);
+    if (!(C2 && C2->isZero()))
+      return SDValue();
+  }
+
+  // Now, replace with: zext(load i16) -> i32, then bitcast to v4i8
+  auto &DAG = DCI.DAG;
+  // Rebuild the load as i16
+  auto *Load = cast<MemSDNode>(Op0.getNode());
+  SDLoc DL(Load);
+  SDValue LoadI16;
+  if (Load->getOpcode() == NVPTXISD::LoadV2) {
+    LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(),
+                          Load->getPointerInfo(), Load->getAlign(),
+                          Load->getMemOperand()->getFlags());
+  } else {
+    assert(Load->getOpcode() == NVPTXISD::LDUV2);
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    SmallVector<SDValue> Ops;
+    Ops.push_back(Load->getChain());
+    Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL,
+                                  TLI.getPointerTy(DAG.getDataLayout())));
+    for (unsigned i = 1; i < Load->getNumOperands(); ++i)
+      Ops.push_back(Load->getOperand(i));
+    SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other);
+    LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList,
+                                      Ops, MVT::i16, Load->getPointerInfo(),
+                                      Load->getAlign());
+  }
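+  // The LoadV2/LDUV2 node carries its chain as result 2, while the rebuilt
+  // scalar load carries it as result 1; reconnect the chain users.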
+  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
+  SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
+  return DAG.getBitcast(MVT::v4i8, Zext);
+}
+
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))
+    return V;
+  if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI))
+    return V;
+  return SDValue();
+}
+
 static SDValue combineADDRSPACECAST(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
diff --git a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
new file mode 100644
index 0000000000000..019bd3bde8761
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -0,0 +1,106 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+target datalayout = "e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-f128:128:128-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8"
+target triple = "nvptx64-nvidia-cuda"
+
+define void @t1() {
+; CHECK-LABEL: t1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.global.v4.b32 [%rd1], {%r1, 0, 0, 0};
+; CHECK-NEXT:    ret;
+entry:
+  %0 = load <2 x i8>, ptr addrspace(1) null, align 4
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %2 = bitcast <4 x i8> %1 to i32
+  %3 = insertelement <4 x i32> zeroinitializer, i32 %2, i64 0
+  store <4 x i32> %3, ptr addrspace(1) null, align 16
+  ret void
+}
+
+define void @t2() {
+; CHECK-LABEL: t2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.local.b32 [%rd1], %r1;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = load <2 x i8>, ptr addrspace(1) null, align 8
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldg(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldg(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b64 %rd1, [ldg_param_0];
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    mov.b64 %rd2, 0;
+; CHECK-NEXT:    st.local.b32 [%rd2], %r1;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldu(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldu(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b64 %rd1, [ldu_param_0];
+; CHECK-NEXT:    ldu.global.b16 %rs1, [%rd1];
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT:    mov.b64 %rd2, 0;
+; CHECK-NEXT:    st.local.b32 [%rd2], %r1;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+define void @t3() {
+; CHECK-LABEL: t3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.global.v2.b32 [%rd1], {%r1, 0};
+; CHECK-NEXT:    ret;
+  %1 = load <2 x i8>, ptr addrspace(1) null, align 2
+  %insval2 = bitcast <2 x i8> %1 to i16
+  %2 = insertelement <4 x i16> zeroinitializer, i16 %insval2, i32 0
+  store <4 x i16> %2, ptr addrspace(1) null, align 8
+  ret void
+}