@@ -690,8 +690,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(
     RewritePatternSet &patterns, bool convertFP8Arithmetic,
-    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
-    PatternBenefit benefit) {
+    bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
+    Chipset chipset, PatternBenefit benefit) {
 
   if (convertFP8Arithmetic) {
     patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
@@ -702,7 +702,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
   if (allowPackedF16Rtz)
     patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
 
-  if (chipset == kGfx950) {
+  if (supportsScaledExtTrunc) {
     patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
     patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
   }
@@ -720,9 +720,10 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
 
   bool convertFP8Arithmetic =
       *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
+  bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
   arith::populateArithToAMDGPUConversionPatterns(
       patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
-      *maybeChipset);
+      supportsScaledExtTrunc, *maybeChipset);
   if (failed(applyPatternsGreedily(op, std::move(patterns))))
     return signalPassFailure();
 }
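
For context, here is a minimal sketch of how an out-of-tree caller might drive the updated entry point after this change. It is not part of the commit: the helper name lowerArithForChip, the chosen option values, and the use of amdgpu::Chipset(9, 5, 0) in place of the file-local kGfx950 constant are illustrative assumptions.

// A hedged caller sketch, not from the commit. Assumes the post-change
// signature of populateArithToAMDGPUConversionPatterns shown above.
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: populates and applies the arith-to-AMDGPU patterns
// on `op` for the given chip name.
static LogicalResult lowerArithForChip(Operation *op, StringRef chipName) {
  FailureOr<amdgpu::Chipset> chipset = amdgpu::Chipset::parse(chipName);
  if (failed(chipset))
    return failure(); // unrecognized chipset string

  // Mirror the pass's new gating: the scaled ext/trunc patterns are only
  // emitted for gfx950. kGfx950 is file-local to ArithToAMDGPU.cpp, so an
  // external caller spells the same chipset out explicitly.
  bool supportsScaledExtTrunc = *chipset == amdgpu::Chipset(9, 5, 0);

  RewritePatternSet patterns(op->getContext());
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
      /*allowPackedF16Rtz=*/true, supportsScaledExtTrunc, *chipset);
  return applyPatternsGreedily(op, std::move(patterns));
}

The design choice visible in the diff is that the chipset test moves out of populateArithToAMDGPUConversionPatterns and into the pass: the helper now takes a plain supportsScaledExtTrunc flag, so callers decide the gating themselves instead of the pattern-population function hard-coding a gfx950 check.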