-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[RISCV] Add intrinsics for strided segment loads with fixed vectors #151611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV] Add intrinsics for strided segment loads with fixed vectors #151611
Conversation
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-backend-risc-v Author: Min-Yih Hsu (mshockwave) ChangesThese intrinsics are the strided version of Full diff: https://github.com/llvm/llvm-project/pull/151611.diff 3 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsRISCV.td b/llvm/include/llvm/IR/IntrinsicsRISCV.td
index e63a41f4f6764..99f975faeb85e 100644
--- a/llvm/include/llvm/IR/IntrinsicsRISCV.td
+++ b/llvm/include/llvm/IR/IntrinsicsRISCV.td
@@ -1717,6 +1717,16 @@ let TargetPrefix = "riscv" in {
llvm_anyint_ty],
[NoCapture<ArgIndex<0>>, IntrReadMem]>;
+ // Input: (pointer, offset, mask, vl)
+ def int_riscv_sseg # nf # _load_mask
+ : DefaultAttrsIntrinsic<!listconcat([llvm_anyvector_ty],
+ !listsplat(LLVMMatchType<0>,
+ !add(nf, -1))),
+ [llvm_anyptr_ty, llvm_anyint_ty,
+ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>,
+ llvm_anyint_ty],
+ [NoCapture<ArgIndex<0>>, IntrReadMem]>;
+
// Input: (<stored values>..., pointer, mask, vl)
def int_riscv_seg # nf # _store_mask
: DefaultAttrsIntrinsic<[],
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 43e4f8e469905..bd68a340afa55 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1819,6 +1819,13 @@ bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
case Intrinsic::riscv_seg6_load_mask:
case Intrinsic::riscv_seg7_load_mask:
case Intrinsic::riscv_seg8_load_mask:
+ case Intrinsic::riscv_sseg2_load_mask:
+ case Intrinsic::riscv_sseg3_load_mask:
+ case Intrinsic::riscv_sseg4_load_mask:
+ case Intrinsic::riscv_sseg5_load_mask:
+ case Intrinsic::riscv_sseg6_load_mask:
+ case Intrinsic::riscv_sseg7_load_mask:
+ case Intrinsic::riscv_sseg8_load_mask:
return SetRVVLoadStoreInfo(/*PtrOp*/ 0, /*IsStore*/ false,
/*IsUnitStrided*/ false, /*UsePtrVal*/ true);
case Intrinsic::riscv_seg2_store_mask:
@@ -10959,6 +10966,97 @@ static inline SDValue getVCIXISDNodeVOID(SDValue &Op, SelectionDAG &DAG,
return DAG.getNode(Type, SDLoc(Op), Op.getValueType(), Operands);
}
+static SDValue
+convertFixedVectorSegLoadIntrinsics(unsigned IntNo, SDValue Op,
+ const RISCVSubtarget &Subtarget,
+ SelectionDAG &DAG) {
+ bool IsStrided;
+ switch (IntNo) {
+ case Intrinsic::riscv_seg2_load_mask:
+ case Intrinsic::riscv_seg3_load_mask:
+ case Intrinsic::riscv_seg4_load_mask:
+ case Intrinsic::riscv_seg5_load_mask:
+ case Intrinsic::riscv_seg6_load_mask:
+ case Intrinsic::riscv_seg7_load_mask:
+ case Intrinsic::riscv_seg8_load_mask:
+ IsStrided = false;
+ break;
+ case Intrinsic::riscv_sseg2_load_mask:
+ case Intrinsic::riscv_sseg3_load_mask:
+ case Intrinsic::riscv_sseg4_load_mask:
+ case Intrinsic::riscv_sseg5_load_mask:
+ case Intrinsic::riscv_sseg6_load_mask:
+ case Intrinsic::riscv_sseg7_load_mask:
+ case Intrinsic::riscv_sseg8_load_mask:
+ IsStrided = true;
+ break;
+ default:
+ llvm_unreachable("unexpected intrinsic ID");
+ };
+
+ static const Intrinsic::ID VlsegInts[7] = {
+ Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
+ Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
+ Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
+ Intrinsic::riscv_vlseg8_mask};
+ static const Intrinsic::ID VlssegInts[7] = {
+ Intrinsic::riscv_vlsseg2_mask, Intrinsic::riscv_vlsseg3_mask,
+ Intrinsic::riscv_vlsseg4_mask, Intrinsic::riscv_vlsseg5_mask,
+ Intrinsic::riscv_vlsseg6_mask, Intrinsic::riscv_vlsseg7_mask,
+ Intrinsic::riscv_vlsseg8_mask};
+
+ SDLoc DL(Op);
+ unsigned NF = Op->getNumValues() - 1;
+ assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
+ MVT XLenVT = Subtarget.getXLenVT();
+ MVT VT = Op->getSimpleValueType(0);
+ MVT ContainerVT = ::getContainerForFixedLengthVector(DAG, VT, Subtarget);
+ unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
+ ContainerVT.getScalarSizeInBits();
+ EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
+
+ // Operands: (chain, int_id, pointer, mask, vl) or
+ // (chain, int_id, pointer, offset, mask, vl)
+ SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
+ SDValue Mask = Op.getOperand(Op.getNumOperands() - 2);
+ MVT MaskVT = Mask.getSimpleValueType();
+ MVT MaskContainerVT =
+ ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
+ Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
+
+ SDValue IntID = DAG.getTargetConstant(
+ IsStrided ? VlssegInts[NF - 2] : VlsegInts[NF - 2], DL, XLenVT);
+ auto *Load = cast<MemIntrinsicSDNode>(Op);
+
+ SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
+ SmallVector<SDValue, 9> Ops = {
+ Load->getChain(),
+ IntID,
+ DAG.getUNDEF(VecTupTy),
+ Op.getOperand(2),
+ Mask,
+ VL,
+ DAG.getTargetConstant(
+ RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
+ DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
+ // Insert the stride operand.
+ if (IsStrided)
+ Ops.insert(std::next(Ops.begin(), 4), Op.getOperand(3));
+
+ SDValue Result =
+ DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
+ Load->getMemoryVT(), Load->getMemOperand());
+ SmallVector<SDValue, 9> Results;
+ for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
+ SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
+ Result.getValue(0),
+ DAG.getTargetConstant(RetIdx, DL, MVT::i32));
+ Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
+ }
+ Results.push_back(Result.getValue(1));
+ return DAG.getMergeValues(Results, DL);
+}
+
SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
SelectionDAG &DAG) const {
unsigned IntNo = Op.getConstantOperandVal(1);
@@ -10971,57 +11069,16 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
case Intrinsic::riscv_seg5_load_mask:
case Intrinsic::riscv_seg6_load_mask:
case Intrinsic::riscv_seg7_load_mask:
- case Intrinsic::riscv_seg8_load_mask: {
- SDLoc DL(Op);
- static const Intrinsic::ID VlsegInts[7] = {
- Intrinsic::riscv_vlseg2_mask, Intrinsic::riscv_vlseg3_mask,
- Intrinsic::riscv_vlseg4_mask, Intrinsic::riscv_vlseg5_mask,
- Intrinsic::riscv_vlseg6_mask, Intrinsic::riscv_vlseg7_mask,
- Intrinsic::riscv_vlseg8_mask};
- unsigned NF = Op->getNumValues() - 1;
- assert(NF >= 2 && NF <= 8 && "Unexpected seg number");
- MVT XLenVT = Subtarget.getXLenVT();
- MVT VT = Op->getSimpleValueType(0);
- MVT ContainerVT = getContainerForFixedLengthVector(VT);
- unsigned Sz = NF * ContainerVT.getVectorMinNumElements() *
- ContainerVT.getScalarSizeInBits();
- EVT VecTupTy = MVT::getRISCVVectorTupleVT(Sz, NF);
-
- // Operands: (chain, int_id, pointer, mask, vl)
- SDValue VL = Op.getOperand(Op.getNumOperands() - 1);
- SDValue Mask = Op.getOperand(3);
- MVT MaskVT = Mask.getSimpleValueType();
- MVT MaskContainerVT =
- ::getContainerForFixedLengthVector(DAG, MaskVT, Subtarget);
- Mask = convertToScalableVector(MaskContainerVT, Mask, DAG, Subtarget);
-
- SDValue IntID = DAG.getTargetConstant(VlsegInts[NF - 2], DL, XLenVT);
- auto *Load = cast<MemIntrinsicSDNode>(Op);
+ case Intrinsic::riscv_seg8_load_mask:
+ case Intrinsic::riscv_sseg2_load_mask:
+ case Intrinsic::riscv_sseg3_load_mask:
+ case Intrinsic::riscv_sseg4_load_mask:
+ case Intrinsic::riscv_sseg5_load_mask:
+ case Intrinsic::riscv_sseg6_load_mask:
+ case Intrinsic::riscv_sseg7_load_mask:
+ case Intrinsic::riscv_sseg8_load_mask:
+ return convertFixedVectorSegLoadIntrinsics(IntNo, Op, Subtarget, DAG);
- SDVTList VTs = DAG.getVTList({VecTupTy, MVT::Other});
- SDValue Ops[] = {
- Load->getChain(),
- IntID,
- DAG.getUNDEF(VecTupTy),
- Op.getOperand(2),
- Mask,
- VL,
- DAG.getTargetConstant(
- RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC, DL, XLenVT),
- DAG.getTargetConstant(Log2_64(VT.getScalarSizeInBits()), DL, XLenVT)};
- SDValue Result =
- DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, VTs, Ops,
- Load->getMemoryVT(), Load->getMemOperand());
- SmallVector<SDValue, 9> Results;
- for (unsigned int RetIdx = 0; RetIdx < NF; RetIdx++) {
- SDValue SubVec = DAG.getNode(RISCVISD::TUPLE_EXTRACT, DL, ContainerVT,
- Result.getValue(0),
- DAG.getTargetConstant(RetIdx, DL, MVT::i32));
- Results.push_back(convertFromScalableVector(VT, SubVec, DAG, Subtarget));
- }
- Results.push_back(Result.getValue(1));
- return DAG.getMergeValues(Results, DL);
- }
case Intrinsic::riscv_sf_vc_v_x_se:
return getVCIXISDNodeWCHAIN(Op, DAG, RISCVISD::SF_VC_V_X_SE);
case Intrinsic::riscv_sf_vc_v_i_se:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-ssegN-load.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-ssegN-load.ll
new file mode 100644
index 0000000000000..dd63fa0a2473b
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-ssegN-load.ll
@@ -0,0 +1,72 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple riscv64 -mattr=+zve64x,+zvl128b < %s | FileCheck %s
+
+define {<8 x i8>, <8 x i8>} @load_factor2(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor2:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT: vlsseg2e8.v v8, (a0), a1
+; CHECK-NEXT: ret
+ %1 = call { <8 x i8>, <8 x i8> } @llvm.riscv.sseg2.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+ ret {<8 x i8>, <8 x i8>} %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>} @load_factor3(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor3:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT: vlsseg3e8.v v8, (a0), a1
+; CHECK-NEXT: ret
+ %1 = call { <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg3.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+ ret { <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor4(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor4:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT: vlsseg4e8.v v8, (a0), a1
+; CHECK-NEXT: ret
+ %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg4.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+ ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor5(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor5:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT: vlsseg5e8.v v8, (a0), a1
+; CHECK-NEXT: ret
+ %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg5.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+ ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor6(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor6:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT: vlsseg6e8.v v8, (a0), a1
+; CHECK-NEXT: ret
+ %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg6.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+ ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor7(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor7:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT: vlsseg7e8.v v8, (a0), a1
+; CHECK-NEXT: ret
+ %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg7.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+ ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
+
+define {<8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>} @load_factor8(ptr %ptr, i64 %stride) {
+; CHECK-LABEL: load_factor8:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 8, e8, mf2, ta, ma
+; CHECK-NEXT: vlsseg8e8.v v8, (a0), a1
+; CHECK-NEXT: ret
+ %1 = call { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } @llvm.riscv.sseg8.load.mask.v8i8.i64.i64(ptr %ptr, i64 %stride, <8 x i1> splat (i1 true), i64 8)
+ ret { <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8>, <8 x i8> } %1
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
@@ -10959,6 +10966,97 @@ static inline SDValue getVCIXISDNodeVOID(SDValue &Op, SelectionDAG &DAG, | |||
return DAG.getNode(Type, SDLoc(Op), Op.getValueType(), Operands); | |||
} | |||
|
|||
static SDValue | |||
convertFixedVectorSegLoadIntrinsics(unsigned IntNo, SDValue Op, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit, I would call this something more like lowerFixedVectorSegLoadIntrinsics
if its only being called from LowerINTRINSIC_W_CHAIN
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
These intrinsics are the strided version of
llvm.riscv.segN.load
intrinsics.This PR is meant to support #151612