Skip to content

Commit f1b5989

Browse files
Muzammiluddin-Syed-ECESeongjaeP
authored andcommitted
[mlir][ArithToAMDGPU] limit scaling truncf/extf support to gfx950 (llvm#155431)
The current chip guard fails to prevent scaling_extf/truncf patterns from being applied on gfx1100 which does not have scaling support. --------- Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent e177c9e commit f1b5989

File tree

4 files changed

+17
-10
lines changed

4 files changed

+17
-10
lines changed

mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@ namespace arith {
2828
/// is set, values outside the range of the destination type are clamped
2929
/// to the largest value of that type instead of being rewritten to Inf (aka
3030
/// NaN).
31-
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
32-
bool convertFP8Arithmetic,
33-
bool saturateFP8Truncf,
34-
bool allowPackedF16Rtz,
35-
amdgpu::Chipset chipset,
36-
PatternBenefit benefit = 1);
31+
void populateArithToAMDGPUConversionPatterns(
32+
RewritePatternSet &patterns, bool convertFP8Arithmetic,
33+
bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
34+
amdgpu::Chipset chipset, PatternBenefit benefit = 1);
3735
} // namespace arith
3836
} // namespace mlir
3937

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,8 +690,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
690690

691691
void mlir::arith::populateArithToAMDGPUConversionPatterns(
692692
RewritePatternSet &patterns, bool convertFP8Arithmetic,
693-
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
694-
PatternBenefit benefit) {
693+
bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
694+
Chipset chipset, PatternBenefit benefit) {
695695

696696
if (convertFP8Arithmetic) {
697697
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
@@ -702,7 +702,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
702702
if (allowPackedF16Rtz)
703703
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);
704704

705-
if (chipset >= kGfx950) {
705+
if (supportsScaledExtTrunc) {
706706
patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
707707
patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
708708
}
@@ -720,9 +720,10 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
720720

721721
bool convertFP8Arithmetic =
722722
*maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
723+
bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
723724
arith::populateArithToAMDGPUConversionPatterns(
724725
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
725-
*maybeChipset);
726+
supportsScaledExtTrunc, *maybeChipset);
726727
if (failed(applyPatternsGreedily(op, std::move(patterns))))
727728
return signalPassFailure();
728729
}

mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
2+
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1100" | FileCheck %s --check-prefix=CHECK-GFX1100
23

34
// CHECK-LABEL: @conversion_f8_f32_fallback
45
// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
@@ -241,6 +242,9 @@ func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vec
241242

242243
// -----
243244

245+
// CHECK-GFX1100-LABEL: @conversion_scalar
246+
// CHECK-GFX1100: arith.scaling_extf
247+
244248
// CHECK-LABEL: @conversion_scalar
245249
// CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
246250
// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f8E5M2 to vector<1xf8E5M2>

mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
2+
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1100" | FileCheck %s --check-prefix=CHECK-GFX1100
23

34
// CHECK-LABEL: @conversion_f8_fallback
45
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf8E5M2>
@@ -163,6 +164,9 @@ func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector
163164

164165
// -----
165166

167+
// CHECK-GFX1100-LABEL: @conversion_scalar
168+
// CHECK-GFX1100: arith.scaling_truncf
169+
166170
// CHECK-LABEL: @conversion_scalar
167171
// CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
168172
// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f32 to vector<1xf32>

0 commit comments

Comments
 (0)