Commit 401e72c

[RISCV] Add intrinsics for strided segment loads with fixed vectors (#151611)
These intrinsics are the strided versions of the `llvm.riscv.segN.load` intrinsics.
1 parent 6c072c0 commit 401e72c

3 files changed: +189 −50 lines
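The new intrinsics mirror the existing masked segment loads, with one extra operand: a byte stride between consecutive segments, placed right after the pointer. A minimal usage sketch in LLVM IR, following the shape of the test added below (the all-ones mask and the vl of 8 are just example values):

; Load 8 two-field segments whose start addresses are %stride bytes apart;
; all lanes are active and the result is a struct of the two field vectors.
%res = call { <8 x i8>, <8 x i8> } @llvm.riscv.sseg2.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
%field0 = extractvalue { <8 x i8>, <8 x i8> } %res, 0
%field1 = extractvalue { <8 x i8>, <8 x i8> } %res, 1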

llvm/include/llvm/IR/IntrinsicsRISCV.td

Lines changed: 10 additions & 0 deletions
@@ -1717,6 +1717,16 @@ let TargetPrefix = "riscv" in {
                                  llvm_anyint_ty],
                                 [NoCapture<ArgIndex<0>>, IntrReadMem]>;
 
+    // Input: (pointer, offset, mask, vl)
+    def int_riscv_sseg # nf # _load_mask
+        : DefaultAttrsIntrinsic<!listconcat([llvm_anyvector_ty],
+                                            !listsplat(LLVMMatchType<0>,
+                                                       !add(nf, -1))),
+                                [llvm_anyptr_ty, llvm_anyint_ty,
+                                 LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                                 llvm_anyint_ty],
+                                [NoCapture<ArgIndex<0>>, IntrReadMem]>;
+
     // Input: (<stored values>..., pointer, mask, vl)
     def int_riscv_seg # nf # _store_mask
         : DefaultAttrsIntrinsic<[],
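For a given nf, this definition produces one intrinsic with nf vector results. For nf = 2 the instantiated declaration comes out roughly as follows (illustrative; the overloaded-type suffixes are taken from the new test below):

declare { <8 x i8>, <8 x i8> } @llvm.riscv.sseg2.load.mask.v8i8.i64.i64(ptr, i64, <8 x i1>, i64)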

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 107 additions & 50 deletions
@@ -1819,6 +1819,13 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
   case Intrinsic::riscv_seg6_load_mask:
   case Intrinsic::riscv_seg7_load_mask:
   case Intrinsic::riscv_seg8_load_mask:
+  case Intrinsic::riscv_sseg2_load_mask:
+  case Intrinsic::riscv_sseg3_load_mask:
+  case Intrinsic::riscv_sseg4_load_mask:
+  case Intrinsic::riscv_sseg5_load_mask:
+  case Intrinsic::riscv_sseg6_load_mask:
+  case Intrinsic::riscv_sseg7_load_mask:
+  case Intrinsic::riscv_sseg8_load_mask:
     return SetRVVLoadStoreInfo(/*PtrOp*/ 0, /*IsStore*/ false,
                                /*IsUnitStrided*/ false, /*UsePtrVal*/ true);
   case Intrinsic::riscv_seg2_store_mask:
@@ -10938,6 +10945,97 @@ static inline SDValue getVCIXISDNodeVOID(SDValue &Op, SelectionDAG &DAG,
   return DAG.getNode(Type, SDLoc(Op), Op.getValueType(), Operands);
 }
 
+static SDValue
+lowerFixedVectorSegLoadIntrinsics(unsigned IntNo, SDValue Op,
+                                  const RISCVSubtarget &Subtarget,
+                                  SelectionDAG &DAG) {
+  bool IsStrided;
+  switch (IntNo) {
+  case Intrinsic::riscv_seg2_load_mask:
+  case Intrinsic::riscv_seg3_load_mask:
+  case Intrinsic::riscv_seg4_load_mask:
+  case Intrinsic::riscv_seg5_load_mask:
+  case Intrinsic::riscv_seg6_load_mask:
+  case Intrinsic::riscv_seg7_load_mask:
+  case Intrinsic::riscv_seg8_load_mask:
+    IsStrided = false;
+    break;
+  case Intrinsic::riscv_sseg2_load_mask:
+  case Intrinsic::riscv_sseg3_load_mask:
+  case Intrinsic::riscv_sseg4_load_mask:
+  case Intrinsic::riscv_sseg5_load_mask:
+  case Intrinsic::riscv_sseg6_load_mask:
+  case Intrinsic::riscv_sseg7_load_mask:
+  case Intrinsic::riscv_sseg8_load_mask:
+    IsStrided = true;
+    break;
+  default:
+    llvm_unreachable("unexpected intrinsic ID");
+  };
+
+  static const Intrinsic::ID VlsegInts[7] = {
+      Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
+      Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
+      Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
+      Intrinsic::riscv_vlseg8_mask};
+  static const Intrinsic::ID VlssegInts[7] = {
+      Intrinsic::riscv_vlsseg2_mask, Intrinsic::riscv_vlsseg3_mask,
+      Intrinsic::riscv_vlsseg4_mask, Intrinsic::riscv_vlsseg5_mask,
+      Intrinsic::riscv_vlsseg6_mask, Intrinsic::riscv_vlsseg7_mask,
+      Intrinsic::riscv_vlsseg8_mask};
+
+  SDLoc DL(Op);
+  unsigned NF = Op->getNumValues() - 1;
+  assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
+  MVT XLenVT = Subtarget.getXLenVT();
+  MVT VT = Op->getSimpleValueType(0);
+  MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget);
+  unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
+                ContainerVT.getScalarSizeInBits();
+  EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
+
+  // Operands: (chain, int_id, pointer, mask, vl) or
+  // (chain, int_id, pointer, offset, mask, vl)
+  SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
+  SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
+  MVT MaskVT = Mask.getSimpleValueType();
+  MVT MaskContainerVT =
+      ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
+  Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
+
+  SDValue IntID = DAG.getTargetConstant(
+      IsStrided ? VlssegInts[NF - 2] : VlsegInts[NF - 2], DL, XLenVT);
+  auto *Load = cast<MemIntrinsicSDNode>(Op);
+
+  SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
+  SmallVector<SDValue, 9> Ops = {
+      Load->getChain(),
+      IntID,
+      DAG.getUNDEF(VecTupTy),
+      Op.getOperand(2),
+      Mask,
+      VL,
+      DAG.getTargetConstant(
+          RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
+      DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
+  // Insert the stride operand.
+  if (IsStrided)
+    Ops.insert(std::next(Ops.begin(), 4), Op.getOperand(3));
+
+  SDValue Result =
+      DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
+                              Load->getMemoryVT(), Load->getMemOperand());
+  SmallVector<SDValue, 9> Results;
+  for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
+    SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
+                                 Result.getValue(0),
+                                 DAG.getTargetConstant(RetIdx, DL, MVT::i32));
+    Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
+  }
+  Results.push_back(Result.getValue(1));
+  return DAG.getMergeValues(Results, DL);
+}
+
 SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
                                                     SelectionDAG &DAG) const {
   unsigned IntNo = Op.getConstantOperandVal(1);
@@ -10950,57 +11048,16 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
   case Intrinsic::riscv_seg5_load_mask:
   case Intrinsic::riscv_seg6_load_mask:
   case Intrinsic::riscv_seg7_load_mask:
-  case Intrinsic::riscv_seg8_load_mask: {
-    SDLoc DL(Op);
-    static const Intrinsic::ID VlsegInts[7] = {
-        Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
-        Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
-        Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
-        Intrinsic::riscv_vlseg8_mask};
-    unsigned NF = Op->getNumValues() - 1;
-    assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
-    MVT XLenVT = Subtarget.getXLenVT();
-    MVT VT = Op->getSimpleValueType(0);
-    MVT ContainerVT = getContainerForFixedLengthVector(VT);
-    unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
-                  ContainerVT.getScalarSizeInBits();
-    EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
-
-    // Operands: (chain, int_id, pointer, mask, vl)
-    SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
-    SDValue Mask = Op.getOperand(3);
-    MVT MaskVT = Mask.getSimpleValueType();
-    MVT MaskContainerVT =
-        ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
-    Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
-
-    SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
-    auto *Load = cast<MemIntrinsicSDNode>(Op);
+  case Intrinsic::riscv_seg8_load_mask:
+  case Intrinsic::riscv_sseg2_load_mask:
+  case Intrinsic::riscv_sseg3_load_mask:
+  case Intrinsic::riscv_sseg4_load_mask:
+  case Intrinsic::riscv_sseg5_load_mask:
+  case Intrinsic::riscv_sseg6_load_mask:
+  case Intrinsic::riscv_sseg7_load_mask:
+  case Intrinsic::riscv_sseg8_load_mask:
+    return lowerFixedVectorSegLoadIntrinsics(IntNo, Op, Subtarget, DAG);
 
-    SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
-    SDValue Ops[] = {
-        Load->getChain(),
-        IntID,
-        DAG.getUNDEF(VecTupTy),
-        Op.getOperand(2),
-        Mask,
-        VL,
-        DAG.getTargetConstant(
-            RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
-        DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
-    SDValue Result =
-        DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
-                                Load->getMemoryVT(), Load->getMemOperand());
-    SmallVector<SDValue, 9> Results;
-    for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
-      SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
-                                   Result.getValue(0),
-                                   DAG.getTargetConstant(RetIdx, DL, MVT::i32));
-      Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
-    }
-    Results.push_back(Result.getValue(1));
-    return DAG.getMergeValues(Results, DL);
-  }
   case Intrinsic::riscv_sf_vc_v_x_se:
     return getVCIXISDNodeWCHAIN(Op, DAG, RISCVISD::SF_VC_V_X_SE);
   case Intrinsic::riscv_sf_vc_v_i_se:
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple riscv64 -mattr=+zve64x,+zvl128b < %s | FileCheck %s
+
+define {<8 x i8>, <8 x i8>} @load_factor2(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vlsseg2e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  %1 = call { <8 x i8>, <8 x i8> } @llvm.riscv.sseg2.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret {<8 x i8>, <8 x i8>} %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>} @load_factor3(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vlsseg3e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  %1 = call { <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg3.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret { <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor4(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vlsseg4e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg4.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor5(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vlsseg5e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg5.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor6(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor6:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vlsseg6e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg6.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor7(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor7:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vlsseg7e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg7.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor8(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vlsseg8e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg8.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
