
Conversation


@matthiasdiener matthiasdiener commented Sep 26, 2025

This enables the float8_e5m2 part of the test for AMD's gfx950 devices, which have gained support for this data type (see https://rocm.docs.amd.com/en/latest/reference/precision-support.html). The test was already enabled for CUDA devices.
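
For context, a minimal sketch of the kind of device gating this change implies, assuming the tutorial's usual target query through triton.runtime.driver and a hypothetical helper name (the exact code in 03-matrix-multiplication.py is not reproduced here):

import torch
import triton

def supports_fp8e5m2():
    # Hypothetical helper: the float8_e5m2 benchmark runs on CUDA devices and,
    # with this change, also on AMD gfx950 (per the ROCm precision-support docs).
    target = triton.runtime.driver.active.get_current_target()
    if target.backend == "cuda":
        return True
    if target.backend == "hip":
        # Assumption: the gfx950 architecture string identifies supported devices.
        return "gfx950" in target.arch
    return False

TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and supports_fp8e5m2():
    # ... run the float8_e5m2 part of the tutorial benchmark ...
    pass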

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because tutorials/03-matrix-multiplication.py does not appear to be part of a test suite. I have tested this manually on a gfx950.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@matthiasdiener matthiasdiener force-pushed the enable-fp8-mm-tutorial-gfx950 branch from 9b19cd3 to 4595b65 on September 26, 2025 21:15
 {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
 ]
-return [triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2) for s in sizes]
+return [triton.Config(s | {'matrix_instr_nonkdim': 32}, num_warps=8, num_stages=2) for s in sizes]
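
For readers without the surrounding file, a hedged sketch of how a config list like this is typically consumed; the helper name and the autotune wiring are assumptions, and only the last size entry and the return line come from the diff above:

import triton

def get_hip_autotune_config():
    # Assumed helper shape; the diff above only changes matrix_instr_nonkdim from 16 to 32.
    sizes = [
        # ... other block-size dictionaries elided ...
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
    ]
    return [triton.Config(s | {'matrix_instr_nonkdim': 32}, num_warps=8, num_stages=2) for s in sizes]

# The list is then handed to the autotuner, e.g.:
# @triton.autotune(configs=get_hip_autotune_config(), key=['M', 'N', 'K'])
# def matmul_kernel(...): ...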
Collaborator

Why change this? It seems unrelated to the purpose of this pull request.

Author

This is the only configuration I found that works with both fp16 and fp8. Would you prefer that I split the function into separate configs for fp16 and fp8?
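
A minimal sketch of what such a split could look like (the use_fp8 parameter and the helper name are hypothetical, not code from this PR):

def get_hip_autotune_config(use_fp8: bool = False):
    # Hypothetical split: keep the original instruction shape for fp16 and
    # select matrix_instr_nonkdim = 32 only on the fp8 path.
    sizes = [
        # ... same block-size dictionaries as in the diff above ...
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
    ]
    nonkdim = 32 if use_fp8 else 16
    return [triton.Config(s | {'matrix_instr_nonkdim': nonkdim}, num_warps=8, num_stages=2) for s in sizes]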

Collaborator

For fp8 to use the mfma_16x16x128 instruction, BLOCK_K needs to be >= 128. The current configs only have BLOCK_K = 64, which is why it does not work.
I think performance is not a concern in the tutorials, and we are also missing other optimizations, such as enabling buffer_load to avoid branches for masked loads. I'd suggest not touching the tuning config.
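
(Illustration only, not a change proposed in this PR: a hypothetical sizes_fp8 entry satisfying the BLOCK_K >= 128 requirement for mfma_16x16x128 would look like the following.)

sizes_fp8 = [
    # Hedged sketch: with BLOCK_SIZE_K >= 128, the fp8 path could select the
    # mfma_16x16x128 instruction without overriding matrix_instr_nonkdim.
    {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 6},
]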

Author

Would you prefer that I change BLOCK_SIZE_K to 128 instead of changing matrix_instr_nonkdim?

Collaborator

I would prefer that you leave all tuning configs unchanged. We are in the middle of fixing issues related to buffer ops, so the current performance of the tutorial will be temporary anyway.

Author

@matthiasdiener matthiasdiener Sep 28, 2025

Thanks for the clarification on keeping the tutorial configs stable.
To explain why I touched this: it is not for performance reasons. With fp8, the example crashes at compile time on gfx950 during ConvertTritonAMDGPUToLLVM when using the unmodified tuning config:

python: /root/.triton/llvm/llvm-064f02da-ubuntu-x64/include/llvm/ADT/SmallVector.h:292: T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](size_type) [with T = mlir::Value; <template-parameter-1-2> = void; reference = mlir::Value&; size_type = long unsigned int]: Assertion `idx < size()' failed.
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 2], instrShape = [16, 16, 128], isTransposed = true}>
#shared = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [1, 0]}>
#shared1 = #ttg.swizzled_shared<{vec = 16, perPhase = 4, maxPhase = 4, order = [0, 1]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx950", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @matmul_kernel(%arg0: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<f8E5M2> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<64> : tensor<32x64xi32, #blocked>
    %cst_0 = arith.constant dense<64> : tensor<64x32xi32, #blocked1>
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c6_i32 = arith.constant 6 : i32
    %c0_i32 = arith.constant 0 : i32
    %true = arith.constant true
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x32xf8E5M2, #blocked1>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x64xf8E5M2, #blocked>
    %c63_i32 = arith.constant 63 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
    %0 = tt.get_program_id x : i32
    %1 = arith.addi %arg3, %c31_i32 : i32
    %2 = arith.divsi %1, %c32_i32 : i32
    %3 = arith.addi %arg4, %c31_i32 : i32
    %4 = arith.divsi %3, %c32_i32 : i32
    %5 = arith.muli %4, %c6_i32 : i32
    %6 = arith.divsi %0, %5 : i32
    %7 = arith.muli %6, %c6_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.minsi %8, %c6_i32 : i32
    %10 = arith.remsi %0, %5 : i32
    %11 = arith.remsi %10, %9 : i32
    %12 = arith.addi %7, %11 : i32
    %13 = arith.divsi %10, %9 : i32
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    llvm.intr.assume %true : i1
    %14 = arith.muli %12, %c32_i32 : i32
    %15 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %16 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %17 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %18 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %19 = tt.splat %14 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %20 = tt.splat %14 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %21 = arith.addi %19, %15 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %22 = arith.addi %20, %16 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>>
    %23 = tt.splat %arg3 : i32 -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %24 = arith.remsi %21, %23 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
    %25 = arith.muli %13, %c32_i32 : i32
    %26 = tt.splat %25 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %27 = tt.splat %25 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %28 = arith.addi %26, %17 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %29 = arith.addi %27, %18 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>>
    %30 = tt.splat %arg4 : i32 -> tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %31 = arith.remsi %28, %30 : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
    %32 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
    %33 = tt.expand_dims %32 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked>
    %34 = tt.expand_dims %24 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32x1xi32, #blocked>
    %35 = tt.splat %arg6 : i32 -> tensor<32x1xi32, #blocked>
    %36 = arith.muli %34, %35 : tensor<32x1xi32, #blocked>
    %37 = tt.broadcast %36 : tensor<32x1xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %38 = tt.broadcast %33 : tensor<1x64xi32, #blocked> -> tensor<32x64xi32, #blocked>
    %39 = arith.addi %37, %38 : tensor<32x64xi32, #blocked>
    %40 = arith.addi %arg5, %c63_i32 : i32
    %41 = arith.divsi %40, %c64_i32 : i32
    %42 = arith.cmpi sgt, %41, %c0_i32 : i32
    %43 = tt.splat %42 : i1 -> tensor<32x64xi1, #blocked>
    %44 = tt.splat %arg5 : i32 -> tensor<1x64xi32, #blocked>
    %45 = arith.cmpi slt, %33, %44 : tensor<1x64xi32, #blocked>
    %46 = tt.broadcast %45 : tensor<1x64xi1, #blocked> -> tensor<32x64xi1, #blocked>
    %47 = arith.andi %43, %46 : tensor<32x64xi1, #blocked>
    %48 = amdgpu.buffer_load %arg0[%39], %47 stride = %arg6 : tensor<32x64xf8E5M2, #blocked>
    %49 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
    %50 = tt.expand_dims %49 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1>
    %51 = tt.broadcast %50 : tensor<64x1xi32, #blocked1> -> tensor<64x32xi32, #blocked1>
    %52 = tt.expand_dims %31 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x32xi32, #blocked1>
    %53 = tt.splat %arg7 : i32 -> tensor<1x32xi32, #blocked1>
    %54 = arith.muli %52, %53 : tensor<1x32xi32, #blocked1>
    %55 = tt.broadcast %54 : tensor<1x32xi32, #blocked1> -> tensor<64x32xi32, #blocked1>
    %56 = arith.addi %51, %55 : tensor<64x32xi32, #blocked1>
    %57 = tt.splat %42 : i1 -> tensor<64x32xi1, #blocked1>
    %58 = tt.splat %arg5 : i32 -> tensor<64x1xi32, #blocked1>
    %59 = arith.cmpi slt, %50, %58 : tensor<64x1xi32, #blocked1>
    %60 = tt.broadcast %59 : tensor<64x1xi1, #blocked1> -> tensor<64x32xi1, #blocked1>
    %61 = arith.andi %57, %60 : tensor<64x32xi1, #blocked1>
    %62 = amdgpu.buffer_load %arg1[%56], %61 stride = %arg7 : tensor<64x32xf8E5M2, #blocked1>
    %63 = ttg.local_alloc : () -> !ttg.memdesc<1x32x64xf8E5M2, #shared, #smem, mutable>
    %64 = ttg.local_alloc : () -> !ttg.memdesc<1x64x32xf8E5M2, #shared1, #smem, mutable>
    %65 = ttg.memdesc_index %63[%c0_i32] : !ttg.memdesc<1x32x64xf8E5M2, #shared, #smem, mutable> -> !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>
    ttg.local_store %48, %65 : tensor<32x64xf8E5M2, #blocked> -> !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>
    %66 = ttg.memdesc_index %64[%c0_i32] : !ttg.memdesc<1x64x32xf8E5M2, #shared1, #smem, mutable> -> !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
    ttg.local_store %62, %66 : tensor<64x32xf8E5M2, #blocked1> -> !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
    %67 = arith.subi %41, %c1_i32 : i32
    %68:6 = scf.for %arg9 = %c0_i32 to %67 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %39, %arg12 = %56, %arg13 = %c0_i32, %arg14 = %65, %arg15 = %66) -> (tensor<32x32xf32, #mma>, tensor<32x64xi32, #blocked>, tensor<64x32xi32, #blocked1>, i32, !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>, !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>)  : i32 {
      %95 = arith.addi %arg11, %cst : tensor<32x64xi32, #blocked>
      %96 = arith.addi %arg12, %cst_0 : tensor<64x32xi32, #blocked1>
      %97 = arith.addi %arg9, %c1_i32 : i32
      %98 = arith.muli %97, %c64_i32 : i32
      %99 = arith.subi %arg5, %98 : i32
      %100 = tt.splat %99 : i32 -> tensor<1x64xi32, #blocked>
      %101 = arith.cmpi slt, %33, %100 : tensor<1x64xi32, #blocked>
      %102 = tt.broadcast %101 : tensor<1x64xi1, #blocked> -> tensor<32x64xi1, #blocked>
      %103 = tt.splat %arg0 : !tt.ptr<f8E5M2> -> tensor<32x64x!tt.ptr<f8E5M2>, #blocked>
      %104 = tt.addptr %103, %95 : tensor<32x64x!tt.ptr<f8E5M2>, #blocked>, tensor<32x64xi32, #blocked>
      %105 = tt.load %104, %102, %cst_2 : tensor<32x64x!tt.ptr<f8E5M2>, #blocked>
      %106 = ttg.local_load %arg14 : !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64> -> tensor<32x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
      %107 = tt.splat %99 : i32 -> tensor<64x1xi32, #blocked1>
      %108 = arith.cmpi slt, %50, %107 : tensor<64x1xi32, #blocked1>
      %109 = tt.broadcast %108 : tensor<64x1xi1, #blocked1> -> tensor<64x32xi1, #blocked1>
      %110 = tt.splat %arg1 : !tt.ptr<f8E5M2> -> tensor<64x32x!tt.ptr<f8E5M2>, #blocked1>
      %111 = tt.addptr %110, %96 : tensor<64x32x!tt.ptr<f8E5M2>, #blocked1>, tensor<64x32xi32, #blocked1>
      %112 = tt.load %111, %109, %cst_1 : tensor<64x32x!tt.ptr<f8E5M2>, #blocked1>
      %113 = ttg.local_load %arg15 : !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32> -> tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
      %114 = tt.dot_scaled %106, %113, %arg10 lhs = e5m2 rhs = e5m2 {fastMath = false} : tensor<32x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<32x32xf32, #mma>
      %115 = arith.addi %arg13, %c1_i32 : i32
      %116 = arith.cmpi slt, %115, %c1_i32 : i32
      %117 = arith.select %116, %115, %c0_i32 : i32
      %118 = ttg.memdesc_index %63[%117] : !ttg.memdesc<1x32x64xf8E5M2, #shared, #smem, mutable> -> !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>
      ttg.local_store %105, %118 : tensor<32x64xf8E5M2, #blocked> -> !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>
      %119 = ttg.memdesc_index %64[%117] : !ttg.memdesc<1x64x32xf8E5M2, #shared1, #smem, mutable> -> !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
      ttg.local_store %112, %119 : tensor<64x32xf8E5M2, #blocked1> -> !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
      scf.yield %114, %95, %96, %117, %118, %119 : tensor<32x32xf32, #mma>, tensor<32x64xi32, #blocked>, tensor<64x32xi32, #blocked1>, i32, !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64>, !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32>
    }
    %69 = arith.cmpi sge, %41, %c1_i32 : i32
    %70 = ttg.local_load %68#4 : !ttg.memdesc<32x64xf8E5M2, #shared, #smem, mutable, 1x32x64> -> tensor<32x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
    %71 = ttg.local_load %68#5 : !ttg.memdesc<64x32xf8E5M2, #shared1, #smem, mutable, 1x64x32> -> tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>>
    %72 = scf.if %69 -> (tensor<32x32xf32, #mma>) {
      %95 = tt.dot_scaled %70, %71, %68#0 lhs = e5m2 rhs = e5m2 {fastMath = false} : tensor<32x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> * tensor<64x32xf8E5M2, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 16}>> -> tensor<32x32xf32, #mma>
      scf.yield %95 : tensor<32x32xf32, #mma>
    } else {
      scf.yield %68#0 : tensor<32x32xf32, #mma>
    }
    %73 = arith.select %69, %72, %68#0 : tensor<32x32xf32, #mma>
    ttg.local_dealloc %64 : !ttg.memdesc<1x64x32xf8E5M2, #shared1, #smem, mutable>
    ttg.local_dealloc %63 : !ttg.memdesc<1x32x64xf8E5M2, #shared, #smem, mutable>
    %74 = arith.truncf %73 : tensor<32x32xf32, #mma> to tensor<32x32xf16, #mma>
    %75 = tt.expand_dims %22 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
    %76 = arith.muli %arg8, %14 : i32
    %77 = tt.expand_dims %29 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
    %78 = tt.expand_dims %16 {axis = 1 : i32} : tensor<32xi32, #ttg.slice<{dim = 1, parent = #mma}>> -> tensor<32x1xi32, #mma>
    %79 = tt.splat %arg8 : i32 -> tensor<32x1xi32, #mma>
    %80 = arith.muli %79, %78 : tensor<32x1xi32, #mma>
    %81 = tt.broadcast %80 : tensor<32x1xi32, #mma> -> tensor<32x32xi32, #mma>
    %82 = tt.expand_dims %18 {axis = 0 : i32} : tensor<32xi32, #ttg.slice<{dim = 0, parent = #mma}>> -> tensor<1x32xi32, #mma>
    %83 = tt.broadcast %82 : tensor<1x32xi32, #mma> -> tensor<32x32xi32, #mma>
    %84 = arith.addi %76, %25 : i32
    %85 = arith.addi %81, %83 : tensor<32x32xi32, #mma>
    %86 = tt.splat %84 : i32 -> tensor<32x32xi32, #mma>
    %87 = arith.addi %86, %85 : tensor<32x32xi32, #mma>
    %88 = tt.splat %arg3 : i32 -> tensor<32x1xi32, #mma>
    %89 = arith.cmpi slt, %75, %88 : tensor<32x1xi32, #mma>
    %90 = tt.splat %arg4 : i32 -> tensor<1x32xi32, #mma>
    %91 = arith.cmpi slt, %77, %90 : tensor<1x32xi32, #mma>
    %92 = tt.broadcast %89 : tensor<32x1xi1, #mma> -> tensor<32x32xi1, #mma>
    %93 = tt.broadcast %91 : tensor<1x32xi1, #mma> -> tensor<32x32xi1, #mma>
    %94 = arith.andi %92, %93 : tensor<32x32xi1, #mma>
    amdgpu.buffer_store %74, %arg2[%87], %94 : tensor<32x32xf16, #mma>
    tt.return
  }
}

{-#
  external_resources: {
    mlir_reproducer: {
      pipeline: "builtin.module(optimize-amd-lds-usage{lds-limit=0 target-arch=gfx950}, convert-scf-to-cf, gluon-inline, convert-index-to-llvm{index-bitwidth=0}, allocate-amdgpu-shared-memory, convert-triton-amdgpu-to-llvm{arch=gfx950 ftz=true}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, convert-cf-to-llvm{index-bitwidth=0}, convert-arith-to-llvm{index-bitwidth=0}, canonicalize{  max-iterations=10 max-num-rewrites=-1 region-simplify=normal test-convergence=false top-down=true}, cse, symbol-dce, enable-line-info, convert-builtin-func-to-llvm{ftz=true})",
      disable_threading: true,
      verify_each: true
    }
  }
#-}
//tritonmm.py:91:0: error: Failures have been detected while processing an MLIR pass pipeline
//tritonmm.py:91:0: note: Pipeline failed while executing [ConvertTritonAMDGPUToLLVM on 'builtin.module' operation]: reproducer generated at std::errs, please share the reproducer above with Triton project.
Traceback (most recent call last):
  File "//tritonmm.py", line 231, in <module>
    triton_output = matmul(a, b)
                    ^^^^^^^^^^^^
  File "//tritonmm.py", line 197, in matmul
    matmul_kernel[grid](
  File "//venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 359, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 240, in run
    benchmark()
  File "//venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 229, in benchmark
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 164, in _bench
    return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/testing.py", line 149, in do_bench
    fn()
  File "//venv/lib/python3.12/site-packages/triton/runtime/autotuner.py", line 150, in kernel_call
    self.fn.run(
  File "//venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 675, in run
    kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/runtime/jit.py", line 803, in _do_compile
    kernel = self.compile(src, target=target, options=options.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/compiler/compiler.py", line 320, in compile
    next_module = compile_ir(module, metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/backends/amd/compiler.py", line 474, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "//venv/lib/python3.12/site-packages/triton/backends/amd/compiler.py", line 326, in make_llir
    pm.run(mod, 'make_llir')
RuntimeError: PassManager::run failed

Author

Based on this, would it be OK to include the change to matrix_instr_nonkdim = 32?

@matthiasdiener matthiasdiener force-pushed the enable-fp8-mm-tutorial-gfx950 branch 3 times, most recently from 9c1ef69 to 124fffa on September 29, 2025 21:29
@matthiasdiener matthiasdiener force-pushed the enable-fp8-mm-tutorial-gfx950 branch from 124fffa to c241663 on September 30, 2025 04:37