Skip to content

Commit 2f2a400

Browse files
hoist chipset check to variable
Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 5464787 commit 2f2a400

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@ namespace arith {
2828
/// is set, values outside the range of the destination type are clamped
2929
/// to the largest value of that type instead of being rewritten to Inf (aka
3030
/// NaN).
31-
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
32-
bool convertFP8Arithmetic,
33-
bool saturateFP8Truncf,
34-
bool allowPackedF16Rtz,
35-
amdgpu::Chipset chipset,
36-
PatternBenefit benefit = 1);
31+
void populateArithToAMDGPUConversionPatterns(
32+
RewritePatternSet &patterns, bool convertFP8Arithmetic,
33+
bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
34+
amdgpu::Chipset chipset, PatternBenefit benefit = 1);
3735
} // namespace arith
3836
} // namespace mlir
3937

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,8 +690,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
690690

691691
void mlir::arith::populateArithToAMDGPUConversionPatterns(
692692
RewritePatternSet &patterns, bool convertFP8Arithmetic,
693-
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
694-
PatternBenefit benefit) {
693+
bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
694+
Chipset chipset, PatternBenefit benefit) {
695695

696696
if (convertFP8Arithmetic) {
697697
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
@@ -702,7 +702,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
702702
if (allowPackedF16Rtz)
703703
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
704704

705-
if (chipset == kGfx950) {
705+
if (supportsScaledExtTrunc) {
706706
patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
707707
patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
708708
}
@@ -720,9 +720,10 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
720720

721721
bool convertFP8Arithmetic =
722722
*maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
723+
bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
723724
arith::populateArithToAMDGPUConversionPatterns(
724725
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
725-
*maybeChipset);
726+
supportsScaledExtTrunc, *maybeChipset);
726727
if (failed(applyPatternsGreedily(op, std::move(patterns))))
727728
return signalPassFailure();
728729
}

0 commit comments

Comments
 (0)