Commit c065ed3

[RISCV] Add intrinsics for strided segment stores with fixed vectors (#152038)
These are the strided versions of the existing `riscv.segN.store.mask` intrinsics.
1 parent 10088b6 commit c065ed3
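
Compared with the unit-stride `riscv.segN.store.mask` intrinsics, each strided variant takes one extra operand, the byte stride, placed between the pointer and the mask. Below is an annotated sketch of the two-field call, copied from the new test in this commit and wrapped in an illustrative function (the wrapper name is not part of the patch):

; Annotated sketch of the sseg2 call exercised by the tests below.
define void @sseg2_example(<8 x i8> %v0, <8 x i8> %v1, ptr %ptr, i64 %stride) {
  ; Operand order: fields..., pointer, stride (the new operand), mask, vl.
  call void @llvm.riscv.sseg2.store.mask.v8i8.i64.i64(
      <8 x i8> %v0,              ; field 0 of each segment
      <8 x i8> %v1,              ; field 1 of each segment
      ptr %ptr,                  ; base address
      i64 %stride,               ; byte stride between consecutive segments
      <8 x i1> splat (i1 true),  ; mask: all 8 lanes active
      i64 8)                     ; vl
  ret void
}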

File tree: 3 files changed (+193, -50 lines)

llvm/include/llvm/IR/IntrinsicsRISCV.td

Lines changed: 11 additions & 0 deletions

@@ -1736,6 +1736,17 @@ let TargetPrefix = "riscv" in {
                                           [llvm_anyptr_ty, LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
                                            llvm_anyint_ty]),
                               [NoCapture<ArgIndex<nf>>, IntrWriteMem]>;
+
+  // Input: (<stored values>..., pointer, stride, mask, vl)
+  def int_riscv_sseg # nf # _store_mask
+      : DefaultAttrsIntrinsic<[],
+                              !listconcat([llvm_anyvector_ty],
+                                          !listsplat(LLVMMatchType<0>,
+                                                     !add(nf, -1)),
+                                          [llvm_anyptr_ty, llvm_anyint_ty,
+                                           LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+                                           llvm_anyint_ty]),
+                              [NoCapture<ArgIndex<nf>>, IntrWriteMem]>;
 }
 
 } // TargetPrefix = "riscv"
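
For each concrete segment count nf this instantiates one overloaded intrinsic. As a hedged sketch (the exact name mangling and printed attribute syntax depend on the LLVM version), the nf = 2 instantiation corresponds to a declaration along these lines, using the types exercised by the new test:

; Sketch only: nf = 2 instantiation of the definition above, with the
; operand roles from the "Input:" comment spelled out per argument.
declare void @llvm.riscv.sseg2.store.mask.v8i8.i64.i64(
    <8 x i8>,    ; stored value, field 0 (llvm_anyvector_ty)
    <8 x i8>,    ; stored value, field 1 (nf - 1 copies via !listsplat)
    ptr,         ; pointer (NoCapture<ArgIndex<nf>>; IntrWriteMem overall)
    i64,         ; stride in bytes between consecutive segments
    <8 x i1>,    ; mask, one bit per lane/segment
    i64)         ; vl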

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 110 additions & 50 deletions

@@ -1844,6 +1844,17 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
     return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 3,
                                /*IsStore*/ true,
                                /*IsUnitStrided*/ false, /*UsePtrVal*/ true);
+  case Intrinsic::riscv_sseg2_store_mask:
+  case Intrinsic::riscv_sseg3_store_mask:
+  case Intrinsic::riscv_sseg4_store_mask:
+  case Intrinsic::riscv_sseg5_store_mask:
+  case Intrinsic::riscv_sseg6_store_mask:
+  case Intrinsic::riscv_sseg7_store_mask:
+  case Intrinsic::riscv_sseg8_store_mask:
+    // Operands are (vec, ..., vec, ptr, offset, mask, vl)
+    return SetRVVLoadStoreInfo(/*PtrOp*/ I.arg_size() - 4,
+                               /*IsStore*/ true,
+                               /*IsUnitStrided*/ false, /*UsePtrVal*/ true);
   case Intrinsic::riscv_vlm:
     return SetRVVLoadStoreInfo(/*PtrOp*/ 0,
                                /*IsStore*/ false,
@@ -11084,69 +11095,118 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
   return lowerVectorIntrinsicScalars(Op, DAG, Subtarget);
 }
 
-SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
-                                                 SelectionDAG &DAG) const {
-  unsigned IntNo = Op.getConstantOperandVal(1);
+static SDValue
+lowerFixedVectorSegStoreIntrinsics(unsigned IntNo, SDValue Op,
+                                   const RISCVSubtarget &Subtarget,
+                                   SelectionDAG &DAG) {
+  bool IsStrided;
   switch (IntNo) {
-  default:
-    break;
   case Intrinsic::riscv_seg2_store_mask:
   case Intrinsic::riscv_seg3_store_mask:
   case Intrinsic::riscv_seg4_store_mask:
   case Intrinsic::riscv_seg5_store_mask:
   case Intrinsic::riscv_seg6_store_mask:
   case Intrinsic::riscv_seg7_store_mask:
-  case Intrinsic::riscv_seg8_store_mask: {
-    SDLoc DL(Op);
-    static const Intrinsic::ID VssegInts[] = {
-        Intrinsic::riscv_vsseg2_mask, Intrinsic::riscv_vsseg3_mask,
-        Intrinsic::riscv_vsseg4_mask, Intrinsic::riscv_vsseg5_mask,
-        Intrinsic::riscv_vsseg6_mask, Intrinsic::riscv_vsseg7_mask,
-        Intrinsic::riscv_vsseg8_mask};
+  case Intrinsic::riscv_seg8_store_mask:
+    IsStrided = false;
+    break;
+  case Intrinsic::riscv_sseg2_store_mask:
+  case Intrinsic::riscv_sseg3_store_mask:
+  case Intrinsic::riscv_sseg4_store_mask:
+  case Intrinsic::riscv_sseg5_store_mask:
+  case Intrinsic::riscv_sseg6_store_mask:
+  case Intrinsic::riscv_sseg7_store_mask:
+  case Intrinsic::riscv_sseg8_store_mask:
+    IsStrided = true;
+    break;
+  default:
+    llvm_unreachable("unexpected intrinsic ID");
+  }
 
-    // Operands: (chain, int_id, vec*, ptr, mask, vl)
-    unsigned NF = Op->getNumOperands() - 5;
-    assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
-    MVT XLenVT = Subtarget.getXLenVT();
-    MVT VT = Op->getOperand(2).getSimpleValueType();
-    MVT ContainerVT = getContainerForFixedLengthVector(VT);
-    unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
-                  ContainerVT.getScalarSizeInBits();
-    EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
+  SDLoc DL(Op);
+  static const Intrinsic::ID VssegInts[] = {
+      Intrinsic::riscv_vsseg2_mask, Intrinsic::riscv_vsseg3_mask,
+      Intrinsic::riscv_vsseg4_mask, Intrinsic::riscv_vsseg5_mask,
+      Intrinsic::riscv_vsseg6_mask, Intrinsic::riscv_vsseg7_mask,
+      Intrinsic::riscv_vsseg8_mask};
+  static const Intrinsic::ID VsssegInts[] = {
+      Intrinsic::riscv_vssseg2_mask, Intrinsic::riscv_vssseg3_mask,
+      Intrinsic::riscv_vssseg4_mask, Intrinsic::riscv_vssseg5_mask,
+      Intrinsic::riscv_vssseg6_mask, Intrinsic::riscv_vssseg7_mask,
+      Intrinsic::riscv_vssseg8_mask};
+
+  // Operands: (chain, int_id, vec*, ptr, mask, vl) or
+  // (chain, int_id, vec*, ptr, stride, mask, vl)
+  unsigned NF = Op->getNumOperands() - (IsStrided ? 6 : 5);
+  assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
+  MVT XLenVT = Subtarget.getXLenVT();
+  MVT VT = Op->getOperand(2).getSimpleValueType();
+  MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget);
+  unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
+                ContainerVT.getScalarSizeInBits();
+  EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
 
-    SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
-    SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
-    MVT MaskVT = Mask.getSimpleValueType();
-    MVT MaskContainerVT =
-        ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
-    Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
+  SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
+  SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
+  MVT MaskVT = Mask.getSimpleValueType();
+  MVT MaskContainerVT =
+      ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
+  Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
 
-    SDValue IntID = DAG.getTargetConstant(VssegInts[NF - 2], DL, XLenVT);
-    SDValue Ptr = Op->getOperand(NF + 2);
+  SDValue IntID = DAG.getTargetConstant(
+      IsStrided ? VsssegInts[NF - 2] : VssegInts[NF - 2], DL, XLenVT);
+  SDValue Ptr = Op->getOperand(NF + 2);
 
-    auto *FixedIntrinsic = cast<MemIntrinsicSDNode>(Op);
+  auto *FixedIntrinsic = cast<MemIntrinsicSDNode>(Op);
 
-    SDValue StoredVal = DAG.getUNDEF(VecTupTy);
-    for (unsigned i = 0; i < NF; i++)
-      StoredVal = DAG.getNode(
-          RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
-          convertToScalableVector(
-              ContainerVT, FixedIntrinsic->getOperand(2 + i), DAG, Subtarget),
-          DAG.getTargetConstant(i, DL, MVT::i32));
+  SDValue StoredVal = DAG.getUNDEF(VecTupTy);
+  for (unsigned i = 0; i < NF; i++)
+    StoredVal = DAG.getNode(
+        RISCVISD::TUPLE_INSERT, DL, VecTupTy, StoredVal,
+        convertToScalableVector(ContainerVT, FixedIntrinsic->getOperand(2 + i),
+                                DAG, Subtarget),
+        DAG.getTargetConstant(i, DL, MVT::i32));
+
+  SmallVector<SDValue, 10> Ops = {
+      FixedIntrinsic->getChain(),
+      IntID,
+      StoredVal,
+      Ptr,
+      Mask,
+      VL,
+      DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
+  // Insert the stride operand.
+  if (IsStrided)
+    Ops.insert(std::next(Ops.begin(), 4),
+               Op.getOperand(Op.getNumOperands() - 3));
+
+  return DAG.getMemIntrinsicNode(
+      ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Ops,
+      FixedIntrinsic->getMemoryVT(), FixedIntrinsic->getMemOperand());
+}
+
+SDValue RISCVTargetLowering::LowerINTRINSIC_VOID(SDValue Op,
+                                                 SelectionDAG &DAG) const {
+  unsigned IntNo = Op.getConstantOperandVal(1);
+  switch (IntNo) {
+  default:
+    break;
+  case Intrinsic::riscv_seg2_store_mask:
+  case Intrinsic::riscv_seg3_store_mask:
+  case Intrinsic::riscv_seg4_store_mask:
+  case Intrinsic::riscv_seg5_store_mask:
+  case Intrinsic::riscv_seg6_store_mask:
+  case Intrinsic::riscv_seg7_store_mask:
+  case Intrinsic::riscv_seg8_store_mask:
+  case Intrinsic::riscv_sseg2_store_mask:
+  case Intrinsic::riscv_sseg3_store_mask:
+  case Intrinsic::riscv_sseg4_store_mask:
+  case Intrinsic::riscv_sseg5_store_mask:
+  case Intrinsic::riscv_sseg6_store_mask:
+  case Intrinsic::riscv_sseg7_store_mask:
+  case Intrinsic::riscv_sseg8_store_mask:
+    return lowerFixedVectorSegStoreIntrinsics(IntNo, Op, Subtarget, DAG);
 
-    SDValue Ops[] = {
-        FixedIntrinsic->getChain(),
-        IntID,
-        StoredVal,
-        Ptr,
-        Mask,
-        VL,
-        DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
-
-    return DAG.getMemIntrinsicNode(
-        ISD::INTRINSIC_VOID, DL, DAG.getVTList(MVT::Other), Ops,
-        FixedIntrinsic->getMemoryVT(), FixedIntrinsic->getMemOperand());
-  }
   case Intrinsic::riscv_sf_vc_xv_se:
     return getVCIXISDNodeVOID(Op, DAG, RISCVISD::SF_VC_XV_SE);
   case Intrinsic::riscv_sf_vc_iv_se:
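
At the DAG level, each fixed-vector intrinsic is rewritten to its scalable masked counterpart (riscv_vsseg*_mask or riscv_vssseg*_mask): the fixed source vectors are inserted into a RISC-V vector tuple, the stride (when present) is spliced in right after the pointer, and a log2(SEW) operand is appended. For orientation only, here is a sketch of the strided scalable form at the IR level for nf = 2; the tuple-type name mangling and container types are assumptions for illustration, not taken from this patch:

; Assumed sketch (illustrative types and mangling). Operand order mirrors
; the Ops vector built in the lowering: tuple, ptr, stride, mask, vl, log2(SEW).
call void @llvm.riscv.vssseg2.mask.triscv.vector.tuple_nxv8i8_2t.i64(
    target("riscv.vector.tuple", <vscale x 8 x i8>, 2) %tuple, ; NF=2 fields
    ptr %ptr,                 ; base address
    i64 %stride,              ; byte stride, inserted at index 4 of Ops
    <vscale x 8 x i1> %mask,  ; mask converted to the scalable container type
    i64 %vl,                  ; vl
    i64 3)                    ; log2(SEW): Log2_64(8) for e8 elements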
Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s
+
+define void @store_factor2(<8 x i8> %v0, <8 x i8> %v1, ptr %ptr, i64 %stride) {
+; CHECK-LABEL: store_factor2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vssseg2e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  call void @llvm.riscv.sseg2.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret void
+}
+
+define void @store_factor3(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, ptr %ptr, i64 %stride) {
+; CHECK-LABEL: store_factor3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vssseg3e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  call void @llvm.riscv.sseg3.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret void
+}
+
+define void @store_factor4(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, ptr %ptr, i64 %stride) {
+; CHECK-LABEL: store_factor4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vssseg4e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  call void @llvm.riscv.sseg4.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret void
+}
+
+define void @store_factor5(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, ptr %ptr, i64 %stride) {
+; CHECK-LABEL: store_factor5:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vssseg5e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  call void @llvm.riscv.sseg5.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret void
+}
+
+define void @store_factor6(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, ptr %ptr, i64 %stride) {
+; CHECK-LABEL: store_factor6:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vssseg6e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  call void @llvm.riscv.sseg6.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret void
+}
+
+define void @store_factor7(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, <8 x i8> %v6, ptr %ptr, i64 %stride) {
+; CHECK-LABEL: store_factor7:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vssseg7e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  call void @llvm.riscv.sseg7.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, <8 x i8> %v6, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret void
+}
+
+define void @store_factor8(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, <8 x i8> %v6, <8 x i8> %v7, ptr %ptr, i64 %stride) {
+; CHECK-LABEL: store_factor8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT:    vssseg8e8.v v8, (a0), a1
+; CHECK-NEXT:    ret
+  call void @llvm.riscv.sseg8.store.mask.v8i8.i64.i64(<8 x i8> %v0, <8 x i8> %v1, <8 x i8> %v2, <8 x i8> %v3, <8 x i8> %v4, <8 x i8> %v5, <8 x i8> %v6, <8 x i8> %v7, ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+  ret void
+}
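
Every factor above lowers to a single vsssegNe8.v with the stride in a1. As a reading aid, here is a hedged scalar model of what one active lane of store_factor2 writes, assuming the usual RVV strided segment-store semantics (the NF fields of a segment are stored contiguously, and consecutive segments start %stride bytes apart); this function is illustrative only and not part of the patch:

; Illustrative-only scalar model of lane %i of store_factor2 (mask bit set,
; %i < vl): field 0 lands at ptr + i*stride, field 1 in the byte after it.
define void @store_factor2_lane_ref(<8 x i8> %v0, <8 x i8> %v1, ptr %ptr,
                                    i64 %stride, i64 %i) {
  %off = mul i64 %i, %stride
  %p0 = getelementptr i8, ptr %ptr, i64 %off   ; start of segment i
  %e0 = extractelement <8 x i8> %v0, i64 %i
  store i8 %e0, ptr %p0                        ; field 0
  %p1 = getelementptr i8, ptr %p0, i64 1       ; next byte (e8 fields)
  %e1 = extractelement <8 x i8> %v1, i64 %i
  store i8 %e1, ptr %p1                        ; field 1
  ret void
}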
