Skip to content

[AMDGPU] Optimize rotate/funnel shift pattern matching in instruction selection #149817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,13 @@ void AMDGPUDAGToDAGISel::Select(SDNode *N) {
SelectSTACKRESTORE(N);
return;
}
case ISD::OR: {
if (SDNode *Selected = selectRotateOrFunnelShiftPattern(N)) {
ReplaceNode(N, Selected);
return;
}
break;
}
}

SelectCode(N);
Expand Down Expand Up @@ -4105,6 +4112,168 @@ void AMDGPUDAGToDAGISel::PostprocessISelDAG() {
} while (IsModified);
}

// Matches rotate/funnel shift patterns produced by legalization
// and converts them to v_alignbit_b32 instructions.
Comment on lines +4115 to +4116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need manual selection? Several of these cases look like they belong in the funnel shift combines or could be moved to patterns

// Matches a divergent i32 (or (shl ...), (srl ...)) produced by the
// legalizer's expansion of ROTR/FSHR and selects it to v_alignbit_b32.
//
// Recognized forms (bit width BW = 32):
//   Constant amounts: (x << C) | (y >> (32 - C))       -> alignbit(x, y, 32-C)
//   Variable FSHR:    ((x << 1) << (~z & 31)) | (y >> (z & 31))
//                                                      -> alignbit(x, y, z)
//   Variable ROTR:    (x << ((0 - z) & 31)) | (x >> (z & 31))
//                                                      -> alignbit(x, x, z)
//
// \returns the selected machine node, or nullptr if no pattern matched.
SDNode *AMDGPUDAGToDAGISel::selectRotateOrFunnelShiftPattern(SDNode *N) {
  if (N->getOpcode() != ISD::OR)
    return nullptr;

  // v_alignbit_b32 is a 32-bit VALU instruction; uniform values should keep
  // the scalar expansion.
  if (N->getValueType(0) != MVT::i32)
    return nullptr;

  if (!N->isDivergent())
    return nullptr;

  SDValue LHS = N->getOperand(0);
  SDValue RHS = N->getOperand(1);

  // Accept both operand orders: (shl, srl) and (srl, shl).
  SDNode *ShlNode = nullptr;
  SDNode *SrlNode = nullptr;
  if (LHS.getOpcode() == ISD::SHL && RHS.getOpcode() == ISD::SRL) {
    ShlNode = LHS.getNode();
    SrlNode = RHS.getNode();
  } else if (LHS.getOpcode() == ISD::SRL && RHS.getOpcode() == ISD::SHL) {
    ShlNode = RHS.getNode();
    SrlNode = LHS.getNode();
  } else {
    return nullptr;
  }

  // Extract sources and shift amounts.
  SDValue ShlSrc = ShlNode->getOperand(0);
  SDValue ShlAmt = ShlNode->getOperand(1);
  SDValue SrlSrc = SrlNode->getOperand(0);
  SDValue SrlAmt = SrlNode->getOperand(1);

  // The legalizer expands a variable FSHR as
  //   ((x << 1) << (~z & 31)) | (y >> (z & 31))
  // so that the left shift becomes x << (32 - (z & 31)). Peel the "<< 1" off
  // the SHL source, but remember that we did: the variable FSHR form is only
  // correct *with* the pre-shift, and the constant/ROTR forms are only
  // correct *without* it.
  bool StrippedPreShl = false;
  if (ShlSrc.getOpcode() == ISD::SHL)
    if (ConstantSDNode *PreShlAmt =
            dyn_cast<ConstantSDNode>(ShlSrc.getOperand(1)))
      if (PreShlAmt->getZExtValue() == 1) {
        ShlSrc = ShlSrc.getOperand(0);
        StrippedPreShl = true;
      }

  // Builds the v_alignbit_b32 machine node for the subtarget's encoding.
  auto buildAlignBitInstruction = [&](SDValue AlignBitSrc0,
                                      SDValue AlignBitSrc1,
                                      SDValue ShiftAmount) -> SDNode * {
    SDLoc DL(N);

    // GFX11+ and true16-capable subtargets use the t16/fake16 VOP3
    // encodings; everything else uses the plain e64 form. (Use the class's
    // Subtarget member instead of re-querying the DAG.)
    bool UseTrue16Encoding =
        Subtarget->getGeneration() >= AMDGPUSubtarget::GFX11 ||
        Subtarget->hasTrue16BitInsts();
    unsigned Opcode =
        UseTrue16Encoding ? (Subtarget->useRealTrue16Insts()
                                 ? AMDGPU::V_ALIGNBIT_B32_t16_e64
                                 : AMDGPU::V_ALIGNBIT_B32_fake16_e64)
                          : AMDGPU::V_ALIGNBIT_B32_e64;

    SDValue Ops[8]; // Maximum operands needed.
    unsigned NumOps = 0;

    if (Opcode == AMDGPU::V_ALIGNBIT_B32_t16_e64 ||
        Opcode == AMDGPU::V_ALIGNBIT_B32_fake16_e64) {
      // Extended VOP3 format with source modifiers, clamp and op_sel.
      Ops[0] = CurDAG->getTargetConstant(0, DL, MVT::i32); // src0_modifiers
      Ops[1] = AlignBitSrc0;                               // src0
      Ops[2] = CurDAG->getTargetConstant(0, DL, MVT::i32); // src1_modifiers
      Ops[3] = AlignBitSrc1;                               // src1
      Ops[4] = CurDAG->getTargetConstant(0, DL, MVT::i32); // src2_modifiers
      Ops[5] = ShiftAmount;                                // src2
      Ops[6] = CurDAG->getTargetConstant(0, DL, MVT::i32); // clamp
      Ops[7] = CurDAG->getTargetConstant(0, DL, MVT::i32); // op_sel
      NumOps = 8;
    } else {
      // Regular e64 format: src0, src1, shift amount.
      Ops[0] = AlignBitSrc0;
      Ops[1] = AlignBitSrc1;
      Ops[2] = ShiftAmount;
      NumOps = 3;
    }

    return CurDAG->getMachineNode(Opcode, DL, MVT::i32,
                                  ArrayRef<SDValue>(Ops, NumOps));
  };

  // Case 1: both shift amounts are constants. The effective left-shift
  // amount (including a peeled "<< 1") must complement the right shift to
  // exactly 32, and the alignbit operand must be a valid shift of 1..31.
  ConstantSDNode *ShlConstant = dyn_cast<ConstantSDNode>(ShlAmt);
  ConstantSDNode *SrlConstant = dyn_cast<ConstantSDNode>(SrlAmt);

  if (ShlConstant && SrlConstant) {
    int64_t ShlVal = ShlConstant->getSExtValue() + (StrippedPreShl ? 1 : 0);
    int64_t SrlVal = SrlConstant->getSExtValue();

    if (ShlVal + SrlVal != 32 || SrlVal < 1 || SrlVal > 31)
      return nullptr;

    SDLoc DL(N);
    SDValue ConstAmtNode = CurDAG->getTargetConstant(SrlVal, DL, MVT::i32);

    return buildAlignBitInstruction(ShlSrc, SrlSrc, ConstAmtNode);
  }

  // Peels the (some_value & 31) mask the legalizer puts on variable shift
  // amounts; returns the unmasked value, or an empty SDValue on mismatch.
  auto getShiftAmount = [&](SDValue ShiftAmtVal) -> SDValue {
    if (ShiftAmtVal.getOpcode() == ISD::AND)
      if (ConstantSDNode *MaskNode =
              dyn_cast<ConstantSDNode>(ShiftAmtVal.getOperand(1)))
        if (MaskNode->getZExtValue() == 31)
          return ShiftAmtVal.getOperand(0);

    return SDValue();
  };

  // Case 2: variable shift amounts — both must be masked with "& 31".
  SDValue ShlAmtSrc = getShiftAmount(ShlAmt);
  SDValue SrlAmtSrc = getShiftAmount(SrlAmt);

  if (!ShlAmtSrc || !SrlAmtSrc)
    return nullptr;

  // The SHL amount must be NOT (FSHR) or NEG (ROTR) of the SRL amount.
  SDValue OriginalAmt;
  bool IsRotatePattern = false;

  if (ShlAmtSrc.getOpcode() == ISD::XOR) {
    // FSHR pattern: SHL amount = (~z) & 31, and the SHL source must have
    // carried the extra "<< 1" (peeled above) so the effective left shift is
    // 32 - (z & 31); without the pre-shift the match would be off by one.
    ConstantSDNode *XorMask = dyn_cast<ConstantSDNode>(ShlAmtSrc.getOperand(1));
    if (StrippedPreShl && XorMask && XorMask->getSExtValue() == -1 &&
        ShlAmtSrc.getOperand(0) == SrlAmtSrc) {
      OriginalAmt = SrlAmtSrc;
      IsRotatePattern = false;
    }
  } else if (ShlAmtSrc.getOpcode() == ISD::SUB) {
    // ROTR pattern: SHL amount = (0 - z) & 31 and both shifts must read the
    // same source — otherwise dropping SrlSrc below would miscompile. This
    // expansion carries no "<< 1" pre-shift.
    ConstantSDNode *SubLHS = dyn_cast<ConstantSDNode>(ShlAmtSrc.getOperand(0));
    if (!StrippedPreShl && SubLHS && SubLHS->getZExtValue() == 0 &&
        ShlAmtSrc.getOperand(1) == SrlAmtSrc && ShlSrc == SrlSrc) {
      OriginalAmt = SrlAmtSrc;
      IsRotatePattern = true;
    }
  }

  if (!OriginalAmt)
    return nullptr;

  // A rotate uses the same register for both alignbit sources; FSHR uses the
  // two distinct funnel inputs.
  SDValue AlignBitSrc0 = ShlSrc;
  SDValue AlignBitSrc1 = IsRotatePattern ? ShlSrc : SrlSrc;

  return buildAlignBitInstruction(AlignBitSrc0, AlignBitSrc1, OriginalAmt);
}

AMDGPUDAGToDAGISelLegacy::AMDGPUDAGToDAGISelLegacy(TargetMachine &TM,
CodeGenOptLevel OptLevel)
: SelectionDAGISelLegacy(
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ class AMDGPUDAGToDAGISel : public SelectionDAGISel {
void SelectINTRINSIC_VOID(SDNode *N);
void SelectWAVE_ADDRESS(SDNode *N);
void SelectSTACKRESTORE(SDNode *N);
SDNode *selectRotateOrFunnelShiftPattern(SDNode *N);

protected:
// Include the pieces autogenerated from the target description.
Expand Down
10 changes: 7 additions & 3 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,12 +486,16 @@ AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
setOperationAction({ISD::ADDC, ISD::SUBC, ISD::ADDE, ISD::SUBE}, VT, Legal);
}

// The hardware supports 32-bit FSHR, but not FSHL.
setOperationAction(ISD::FSHR, MVT::i32, Legal);
if (Subtarget->isGCN()) {
setOperationAction(ISD::FSHR, MVT::i32, Expand);
setOperationAction(ISD::ROTR, {MVT::i32, MVT::i64}, Expand);
} else {
setOperationAction(ISD::FSHR, MVT::i32, Legal);
setOperationAction(ISD::ROTR, {MVT::i32, MVT::i64}, Legal);
}

// The hardware supports 32-bit ROTR, but not ROTL.
setOperationAction(ISD::ROTL, {MVT::i32, MVT::i64}, Expand);
setOperationAction(ISD::ROTR, MVT::i64, Expand);

setOperationAction({ISD::MULHU, ISD::MULHS}, MVT::i16, Expand);

Expand Down
Loading
Loading