[NVPTX] Eliminate prmts that result from BUILD_VECTOR of LoadV2 #149581


Closed
77 changes: 76 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N,
}

static SDValue
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
PerformBUILD_VECTOROfV2i16Combine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
auto VT = N->getValueType(0);
if (!DCI.isAfterLegalizeDAG() ||
// only process v2*16 types
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getBitcast(VT, PRMT);
}

static SDValue
PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
// Match: BUILD_VECTOR of v4i8, where first two elements are from a
// NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
// zero constants. Replace with: zext the loaded i16 to i32, and return as a
// bitcast to v4i8.
Comment on lines +5840 to +5843 (Member):
This appears to be an oddly specific pattern with two specific elements loaded, and two specific constants. Is that a particularly common pattern? Where does it come from, if so?

@justinfargnoli (Contributor, Author) replied Jul 24, 2025:

Is that a particularly common pattern?

This is obviously subjective, but I would not say it's a "common" pattern.

Where does it come from, if so?

This is an artifact of type legalization converting vNi8 operations into v4i8. Godbolt Link

What we really want to optimize is the series of prmts that reconstruct the 16-bit load. This BUILD_VECTOR pattern is what generates the series of prmts in all cases that I'm aware of. Thus, as an implementation decision, I wrote this peephole for this very specific BUILD_VECTOR case.

(Previously, I wrote the code for this peephole by looking at the prmts, and it was a little more involved.)

Comment on lines +5840 to +5843 (Member):
This seems like a very specific and uncommon occurrence and the test cases look a bit contrived. Is this something we're going to see in real programs?

@justinfargnoli (Contributor, Author) replied:
The original use case for this optimization was to clean up after an internal pass.

However, the test cases in this PR are llvm-reduced IR from real CUDA programs that didn't utilize said internal pass.

If you think the usefulness isn't worth the maintenance cost, I'm happy to keep it internal.

EVT VT = N->getValueType(0);
if (VT != MVT::v4i8)
return SDValue();
// Check operands: [0]=lo, [1]=hi
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
// Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
// NVPTXISD::LDUV2
if (Op0.getNode() != Op1.getNode())
Copilot AI commented on lines +5850 to +5852 (Jul 18, 2025):
This condition checks if Op0 and Op1 come from the same node, but the comment and function logic suggest they should be consecutive elements from a LoadV2. Consider adding a comment explaining why the same node check is sufficient, or verify that Op0 and Op1 are specifically the low and high bytes of the same load.

Suggested change:
-  // Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
-  // NVPTXISD::LDUV2
-  if (Op0.getNode() != Op1.getNode())
+  // Check that Op0 and Op1 are consecutive elements (low and high bytes)
+  // from the same NVPTXISD::LoadV2 or NVPTXISD::LDUV2
+  if (Op0.getNode() != Op1.getNode() || Op0.getResNo() != 0 || Op1.getResNo() != 1)


@justinfargnoli (Contributor, Author) replied:

Bug: we need to check that they're different values from the same node.

return SDValue();
if (!(Op0.getOpcode() == NVPTXISD::LoadV2 ||
Op0.getOpcode() == NVPTXISD::LDUV2))
return SDValue();
if (Op0.getValueType() != MVT::i16)
return SDValue();
if (!(Op0.hasOneUse() && Op1.hasOneUse()))
return SDValue();

// Check operands: [2] = 0 or undef, [3] = 0 or undef
SDValue Op2 = N->getOperand(2);
SDValue Op3 = N->getOperand(3);
if (Op2 != Op3)
return SDValue();
if (!Op2.isUndef()) {
auto *C2 = dyn_cast<ConstantSDNode>(Op2);
if (!(C2 && C2->isZero()))
return SDValue();
}

// Now, replace with: zext(load i16) -> i32, then bitcast to v4i8
auto &DAG = DCI.DAG;
// Rebuild the load as i16
auto *Load = cast<MemSDNode>(Op0.getNode());
SDLoc DL(Load);
SDValue LoadI16;
if (Load->getOpcode() == NVPTXISD::LoadV2) {
LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(),
Load->getPointerInfo(), Load->getAlign(),
Load->getMemOperand()->getFlags());
} else {
assert(Load->getOpcode() == NVPTXISD::LDUV2);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
SmallVector<SDValue, 4> Ops;
Ops.push_back(Load->getChain());
Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL,
TLI.getPointerTy(DAG.getDataLayout())));
for (unsigned i = 1; i < Load->getNumOperands(); ++i)
Ops.push_back(Load->getOperand(i));
SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other);
LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList,
Ops, MVT::i16, Load->getPointerInfo(),
Load->getAlign());
}
DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
return DAG.getBitcast(MVT::v4i8, Zext);
}

static SDValue
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))
return V;
if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI))
return V;
return SDValue();
}

static SDValue combineADDRSPACECAST(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
106 changes: 106 additions & 0 deletions llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -0,0 +1,106 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -march=nvptx64 | FileCheck %s
; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}

target datalayout = "e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-f128:128:128-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8"
target triple = "nvptx64-nvidia-cuda"

define void @t1() {
; CHECK-LABEL: t1(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: mov.b64 %rd1, 0;
; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r1, 0, 0, 0};
; CHECK-NEXT: ret;
entry:
%0 = load <2 x i8>, ptr addrspace(1) null, align 4
%1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
%2 = bitcast <4 x i8> %1 to i32
%3 = insertelement <4 x i32> zeroinitializer, i32 %2, i64 0
store <4 x i32> %3, ptr addrspace(1) null, align 16
ret void
}

define void @t2() {
; CHECK-LABEL: t2(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: mov.b64 %rd1, 0;
; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
; CHECK-NEXT: st.local.b32 [%rd1], %r1;
; CHECK-NEXT: ret;
entry:
%0 = load <2 x i8>, ptr addrspace(1) null, align 8
%1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
store <4 x i8> %1, ptr addrspace(5) null, align 8
ret void
}

declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)

define void @ldg(ptr addrspace(1) %ptr) {
; CHECK-LABEL: ldg(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b64 %rd1, [ldg_param_0];
; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
; CHECK-NEXT: mov.b64 %rd2, 0;
; CHECK-NEXT: st.local.b32 [%rd2], %r1;
; CHECK-NEXT: ret;
entry:
%0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
%1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
store <4 x i8> %1, ptr addrspace(5) null, align 8
ret void
}

declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)

define void @ldu(ptr addrspace(1) %ptr) {
; CHECK-LABEL: ldu(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b64 %rd1, [ldu_param_0];
; CHECK-NEXT: ldu.global.b16 %rs1, [%rd1];
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
; CHECK-NEXT: mov.b64 %rd2, 0;
; CHECK-NEXT: st.local.b32 [%rd2], %r1;
; CHECK-NEXT: ret;
entry:
%0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
%1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
store <4 x i8> %1, ptr addrspace(5) null, align 8
ret void
}

define void @t3() {
; CHECK-LABEL: t3(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: mov.b64 %rd1, 0;
; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r1, 0};
; CHECK-NEXT: ret;
%1 = load <2 x i8>, ptr addrspace(1) null, align 2
%insval2 = bitcast <2 x i8> %1 to i16
%2 = insertelement <4 x i16> zeroinitializer, i16 %insval2, i32 0
store <4 x i16> %2, ptr addrspace(1) null, align 8
ret void
}