10 changes: 4 additions & 6 deletions mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
@@ -28,12 +28,10 @@ namespace arith {
/// is set, values outside the range of the destination type are clamped
/// to the largest value of that type instead of being rewritten to Inf (aka
/// NaN).
void populateArithToAMDGPUConversionPatterns(RewritePatternSet &patterns,
bool convertFP8Arithmetic,
bool saturateFP8Truncf,
bool allowPackedF16Rtz,
amdgpu::Chipset chipset,
PatternBenefit benefit = 1);
void populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
amdgpu::Chipset chipset, PatternBenefit benefit = 1);
} // namespace arith
} // namespace mlir

9 changes: 5 additions & 4 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -690,8 +690,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,

void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
PatternBenefit benefit) {
bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
Chipset chipset, PatternBenefit benefit) {

if (convertFP8Arithmetic) {
patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
@@ -702,7 +702,7 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
if (allowPackedF16Rtz)
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);

if (chipset >= kGfx950) {
if (supportsScaledExtTrunc) {
patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
}
@@ -720,9 +720,10 @@ void ArithToAMDGPUConversionPass::runOnOperation() {

bool convertFP8Arithmetic =
*maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
supportsScaledExtTrunc, *maybeChipset);
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
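
Since populateArithToAMDGPUConversionPatterns gains an extra boolean parameter, out-of-tree callers have to pass the new flag explicitly. A minimal sketch of how such a caller might adapt, assuming a hypothetical wrapper (the function name, the local kGfx950 constant, and the false defaults below are illustrative, not part of this patch):

// Hypothetical downstream helper showing the updated call shape.
void populateMyAMDGPULoweringPatterns(RewritePatternSet &patterns,
                                      amdgpu::Chipset chipset) {
  // Mirrors the constant defined in ArithToAMDGPU.cpp.
  const amdgpu::Chipset kGfx950(9, 5, 0);
  // Per this patch, only gfx950 provides the scaled ext/trunc intrinsics.
  bool supportsScaledExtTrunc = chipset == kGfx950;
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, /*convertFP8Arithmetic=*/false, /*saturateFP8Truncf=*/false,
      /*allowPackedF16Rtz=*/false, supportsScaledExtTrunc, chipset);
}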
4 changes: 4 additions & 0 deletions mlir/test/Conversion/ArithToAMDGPU/scaling-extf.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1100" | FileCheck %s --check-prefix=CHECK-GFX1100

// CHECK-LABEL: @conversion_f8_f32_fallback
// CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
@@ -241,6 +242,9 @@ func.func @conversion_broadcast(%in: vector<4xf8E5M2>, %scale: f8E8M0FNU) -> vec

// -----

// CHECK-GFX1100-LABEL: @conversion_scalar
// CHECK-GFX1100: arith.scaling_extf
Contributor:
I don't understand the operations involved here. Why does this need target knowledge, to emit something that isn't a target operation? Why isn't there target legalization?

Contributor Author:
This target (gfx1100) does not support scaled mxfp4 types, so there is no viable legalization IIUC (cc @krzysz00).

The main reason this change is necessary is to avoid running the scaled extf/truncf rewrite on unsupported chips when calling populateArithToAMDGPUConversionPatterns (See here).

However, this check does not need to exist in ArithToAMDGPU.cpp. If preferred, I can create a separate function populateConversionPatterns in the same vein as how it's done in Math/Transforms, where we pass in a list of ops we are interested in expanding, moving the check to the caller rather than having it here.

/// Adds patterns to expand math operations into other more fundamental
/// operations. For example, hyperbolic functions are expanded into expressions
/// using `exp`. If `opMnemonics` is empty then all available patterns will be
/// added, otherwise only the patterns corresponding to ops in `opMnemonics`
/// will be added to the set.
void populateExpansionPatterns(RewritePatternSet &patterns,
                               ArrayRef<StringRef> opMnemonics = {});
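
Following that precedent, a rough sketch of what such an op-filtered entry point could look like for this conversion (the function name, the opMnemonics parameter, and the filtering helper below are hypothetical, not existing API):

/// Hypothetical alternative: the caller selects which arith ops to rewrite,
/// instead of threading per-feature booleans through the signature.
void populateConversionPatterns(RewritePatternSet &patterns,
                                ArrayRef<StringRef> opMnemonics,
                                PatternBenefit benefit = 1) {
  auto wants = [&](StringRef name) {
    return opMnemonics.empty() || llvm::is_contained(opMnemonics, name);
  };
  if (wants(arith::ScalingExtFOp::getOperationName()))
    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
  if (wants(arith::ScalingTruncFOp::getOperationName()))
    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
}

The caller would then decide, based on its own chipset knowledge, whether to request arith.scaling_extf / arith.scaling_truncf at all.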

Contributor:
So, these rewrite patterns are specifically for "lower arith operations to target intrinsics if they exist". If the target intrinsic doesn't exist, these patterns shouldn't run, and later patterns (the "generic" expansion over in ExpandArithOps) will run instead.

So, the main point of the test here is to ensure that we don't emit intrinsic calls that can't be fulfilled, or rewrite things into a more complex form that only makes sense if you're targeting intrinsics (the amdgpu.* operations are intrinsic wrappers that provide somewhat higher-level APIs).

I think these target checks are fine right where they are, especially since this lowering already has to do a decent number of target checks (for example, which FP8 formats get lowered to intrinsic calls)

Contributor:
arith.scaling_extf isn't a target-specific operation, but if it has an implementation on a given chipset, this pass lowers it to that target-specific implementation. Hence the test checks that, on gfx1100, this pass is a no-op (because other passes introduce a less efficient lowering instead).

Contributor:
(If there were an easy, low-dependency way to query the CPU features - that is, if we substantially refactored LLVM - these target checks could be checks on the same flags LLVM uses to test for the ultimate presence of the intrinsics. But there isn't, so chipset checks it is.)

Contributor Author:
// CHECK-LABEL: @conversion_scalar
// CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f8E5M2 to vector<1xf8E5M2>
4 changes: 4 additions & 0 deletions mlir/test/Conversion/ArithToAMDGPU/scaling-truncf.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx950" | FileCheck %s
// RUN: mlir-opt --split-input-file %s -convert-arith-to-amdgpu="chipset=gfx1100" | FileCheck %s --check-prefix=CHECK-GFX1100

// CHECK-LABEL: @conversion_f8_fallback
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf8E5M2>
@@ -163,6 +164,9 @@ func.func @conversion_broadcast(%in: vector<4xf32>, %scale: f8E8M0FNU) -> vector

// -----

// CHECK-GFX1100-LABEL: @conversion_scalar
// CHECK-GFX1100: arith.scaling_truncf

// CHECK-LABEL: @conversion_scalar
// CHECK: %[[SCALE_F32:.+]] = arith.extf %arg1 : f8E8M0FNU to f32
// CHECK-NEXT: %[[SPLAT_IN:.+]] = vector.broadcast %arg0 : f32 to vector<1xf32>