-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[NVPTX] Eliminate prmt
s that result from BUILD_VECTOR
of LoadV2
#149581
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N, | |||||||||||||
} | ||||||||||||||
|
||||||||||||||
static SDValue | ||||||||||||||
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { | ||||||||||||||
PerformBUILD_VECTOROfV2i16Combine(SDNode *N, | ||||||||||||||
TargetLowering::DAGCombinerInfo &DCI) { | ||||||||||||||
auto VT = N->getValueType(0); | ||||||||||||||
if (!DCI.isAfterLegalizeDAG() || | ||||||||||||||
// only process v2*16 types | ||||||||||||||
|
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { | |||||||||||||
return DAG.getBitcast(VT, PRMT); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
static SDValue | ||||||||||||||
PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N, | ||||||||||||||
TargetLowering::DAGCombinerInfo &DCI) { | ||||||||||||||
// Match: BUILD_VECTOR of v4i8, where first two elements are from a | ||||||||||||||
// NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are | ||||||||||||||
// zero constants. Replace with: zext the loaded i16 to i32, and return as a | ||||||||||||||
// bitcast to v4i8. | ||||||||||||||
Comment on lines
+5840
to
+5843
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like a very specific and uncommon occurrence and the test cases look a bit contrived. Is this something we're going to see in real programs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The original use case for this optimization was to clean up after an internal pass. However, the test cases in this PR are If you think the usefulness isn't worth the maintenance cost, I'm happy to keep it internal. |
||||||||||||||
EVT VT = N->getValueType(0); | ||||||||||||||
if (VT != MVT::v4i8) | ||||||||||||||
return SDValue(); | ||||||||||||||
// Check operands: [0]=lo, [1]=hi | ||||||||||||||
SDValue Op0 = N->getOperand(0); | ||||||||||||||
SDValue Op1 = N->getOperand(1); | ||||||||||||||
// Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or | ||||||||||||||
// NVPTXISD::LDUV2 | ||||||||||||||
if (Op0.getNode() != Op1.getNode()) | ||||||||||||||
Comment on lines
+5850
to
+5852
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This condition checks if Op0 and Op1 come from the same node, but the comment and function logic suggest they should be consecutive elements from a LoadV2. Consider adding a comment explaining why the same node check is sufficient, or verify that Op0 and Op1 are specifically the low and high bytes of the same load.
Suggested change
Copilot uses AI. Check for mistakes. Positive FeedbackNegative Feedback There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: we need to check that they're different values from the same node. |
||||||||||||||
return SDValue(); | ||||||||||||||
if (!(Op0.getOpcode() == NVPTXISD::LoadV2 || | ||||||||||||||
Op0.getOpcode() == NVPTXISD::LDUV2)) | ||||||||||||||
return SDValue(); | ||||||||||||||
if (Op0.getValueType() != MVT::i16) | ||||||||||||||
return SDValue(); | ||||||||||||||
if (!(Op0.hasOneUse() && Op1.hasOneUse())) | ||||||||||||||
return SDValue(); | ||||||||||||||
|
||||||||||||||
// Check operands: [2]= 0 or undef, [3]= 0 or undef | ||||||||||||||
SDValue Op2 = N->getOperand(2); | ||||||||||||||
SDValue Op3 = N->getOperand(3); | ||||||||||||||
if (Op2 != Op3) | ||||||||||||||
return SDValue(); | ||||||||||||||
if (!Op2.isUndef()) { | ||||||||||||||
auto *C2 = dyn_cast<ConstantSDNode>(Op2); | ||||||||||||||
if (!(C2 && C2->isZero())) | ||||||||||||||
return SDValue(); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
// Now, replace with: zext(load i16) -> i32, then bitcast to v4i8 | ||||||||||||||
auto &DAG = DCI.DAG; | ||||||||||||||
// Rebuild the load as i16 | ||||||||||||||
auto *Load = cast<MemSDNode>(Op0.getNode()); | ||||||||||||||
SDLoc DL(Load); | ||||||||||||||
SDValue LoadI16; | ||||||||||||||
if (Load->getOpcode() == NVPTXISD::LoadV2) { | ||||||||||||||
LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(), | ||||||||||||||
Load->getPointerInfo(), Load->getAlign(), | ||||||||||||||
Load->getMemOperand()->getFlags()); | ||||||||||||||
} else { | ||||||||||||||
assert(Load->getOpcode() == NVPTXISD::LDUV2); | ||||||||||||||
const TargetLowering &TLI = DAG.getTargetLoweringInfo(); | ||||||||||||||
SmallVector<SDValue, 4> Ops; | ||||||||||||||
Ops.push_back(Load->getChain()); | ||||||||||||||
Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL, | ||||||||||||||
TLI.getPointerTy(DAG.getDataLayout()))); | ||||||||||||||
for (unsigned i = 1; i < Load->getNumOperands(); ++i) | ||||||||||||||
Ops.push_back(Load->getOperand(i)); | ||||||||||||||
SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other); | ||||||||||||||
LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList, | ||||||||||||||
Ops, MVT::i16, Load->getPointerInfo(), | ||||||||||||||
Load->getAlign()); | ||||||||||||||
} | ||||||||||||||
DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1)); | ||||||||||||||
justinfargnoli marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||
SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32); | ||||||||||||||
return DAG.getBitcast(MVT::v4i8, Zext); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
static SDValue | ||||||||||||||
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { | ||||||||||||||
if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI)) | ||||||||||||||
return V; | ||||||||||||||
if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI)) | ||||||||||||||
return V; | ||||||||||||||
return SDValue(); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
static SDValue combineADDRSPACECAST(SDNode *N, | ||||||||||||||
TargetLowering::DAGCombinerInfo &DCI) { | ||||||||||||||
auto *ASCN1 = cast<AddrSpaceCastSDNode>(N); | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 | ||
; RUN: llc < %s -march=nvptx64 | FileCheck %s | ||
; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} | ||
|
||
target datalayout = "e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-f128:128:128-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8" | ||
target triple = "nvptx64-nvidia-cuda" | ||
|
||
define void @t1() { | ||
; CHECK-LABEL: t1( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b32 %r<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<2>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: // %entry | ||
; CHECK-NEXT: mov.b64 %rd1, 0; | ||
; CHECK-NEXT: ld.global.b16 %r1, [%rd1]; | ||
; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r1, 0, 0, 0}; | ||
; CHECK-NEXT: ret; | ||
entry: | ||
%0 = load <2 x i8>, ptr addrspace(1) null, align 4 | ||
%1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3> | ||
%2 = bitcast <4 x i8> %1 to i32 | ||
%3 = insertelement <4 x i32> zeroinitializer, i32 %2, i64 0 | ||
store <4 x i32> %3, ptr addrspace(1) null, align 16 | ||
ret void | ||
} | ||
|
||
define void @t2() { | ||
; CHECK-LABEL: t2( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b32 %r<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<2>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: // %entry | ||
; CHECK-NEXT: mov.b64 %rd1, 0; | ||
; CHECK-NEXT: ld.global.b16 %r1, [%rd1]; | ||
; CHECK-NEXT: st.local.b32 [%rd1], %r1; | ||
; CHECK-NEXT: ret; | ||
entry: | ||
%0 = load <2 x i8>, ptr addrspace(1) null, align 8 | ||
%1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3> | ||
store <4 x i8> %1, ptr addrspace(5) null, align 8 | ||
ret void | ||
} | ||
|
||
declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align) | ||
|
||
define void @ldg(ptr addrspace(1) %ptr) { | ||
; CHECK-LABEL: ldg( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b32 %r<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: // %entry | ||
; CHECK-NEXT: ld.param.b64 %rd1, [ldg_param_0]; | ||
; CHECK-NEXT: ld.global.b16 %r1, [%rd1]; | ||
; CHECK-NEXT: mov.b64 %rd2, 0; | ||
; CHECK-NEXT: st.local.b32 [%rd2], %r1; | ||
; CHECK-NEXT: ret; | ||
entry: | ||
%0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2) | ||
%1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3> | ||
store <4 x i8> %1, ptr addrspace(5) null, align 8 | ||
ret void | ||
} | ||
|
||
declare <2 x i8> @llvm.nvvm.ldu.global.f.v2i8.p1(ptr addrspace(1) %ptr, i32 %align) | ||
|
||
define void @ldu(ptr addrspace(1) %ptr) { | ||
; CHECK-LABEL: ldu( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b16 %rs<2>; | ||
; CHECK-NEXT: .reg .b32 %r<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<3>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: // %entry | ||
; CHECK-NEXT: ld.param.b64 %rd1, [ldu_param_0]; | ||
; CHECK-NEXT: ldu.global.b16 %rs1, [%rd1]; | ||
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1; | ||
; CHECK-NEXT: mov.b64 %rd2, 0; | ||
; CHECK-NEXT: st.local.b32 [%rd2], %r1; | ||
; CHECK-NEXT: ret; | ||
entry: | ||
%0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2) | ||
%1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3> | ||
store <4 x i8> %1, ptr addrspace(5) null, align 8 | ||
ret void | ||
} | ||
|
||
define void @t3() { | ||
; CHECK-LABEL: t3( | ||
; CHECK: { | ||
; CHECK-NEXT: .reg .b32 %r<2>; | ||
; CHECK-NEXT: .reg .b64 %rd<2>; | ||
; CHECK-EMPTY: | ||
; CHECK-NEXT: // %bb.0: | ||
; CHECK-NEXT: mov.b64 %rd1, 0; | ||
; CHECK-NEXT: ld.global.b16 %r1, [%rd1]; | ||
; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r1, 0}; | ||
; CHECK-NEXT: ret; | ||
%1 = load <2 x i8>, ptr addrspace(1) null, align 2 | ||
%insval2 = bitcast <2 x i8> %1 to i16 | ||
%2 = insertelement <4 x i16> zeroinitializer, i16 %insval2, i32 0 | ||
store <4 x i16> %2, ptr addrspace(1) null, align 8 | ||
ret void | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This appears to be an oddly specific pattern with two specific elements loaded, and two specific constants. Is that a particularly common pattern? Where does it come from, if so?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is obviously subjective, but I would not say it's a "common" pattern.
This is an artifact of type legalization converting
vNi8
operations intov4i8
. Godbolt LinkWhat we really want to optimize is the series of
prmt
s that reconstruct the 16-bit load. ThisBUILD_VECTOR
pattern is what generates the series ofprmt
s in all cases that I'm aware of. Thus, as an implementation decision, I wrote this peephole for this very specificBUILD_VECTOR
case.(Previously, I wrote the code for this peephole by looking at the
prmt
s and it's a little bit more involved.)