
Conversation

tkarna (Owner) commented on Jul 29, 2025:

xegpu transform ops for matrix multiplication

Purpose

This document outlines new transform.xegpu transform operations.

  • Currently in MLIR there is no mechanism to lower high-level linalg operations to the xegpu dialect, although such a capability would be useful in a number of user applications. The proposed XeGPU transform operations aim to fill this gap for lowering linalg.matmul operations. They also handle the tiling, prefetching, and other optimizations necessary to achieve good performance on Xe GPUs. Going forward, the XeGPU transform ops can be extended to support more workloads.
  • The transform ops expose a number of parameters, such as tile sizes, as tunable knobs for autotuning and code generation applications.
  • The transform ops provide a more fine-grained mechanism for defining lowering schedules compared to conventional MLIR pass pipelines. Specifically, the transform dialect introduces handles to payload operations as SSA values (e.g., a handle to an scf.for op), which allows defining differentiated transforms for each op (e.g., a main loop and a remainder loop after tiling); see the sketch after this list.
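
As a minimal sketch of this mechanism (using only ops that appear later in this document; handle names are illustrative), tiling returns separate handles for the tiled op and the generated loop, and each handle can then be targeted by different follow-up transforms:

%tiled_matmul, %k_loop = transform.structured.tile_using_for %matmul tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
// the loop handle can be transformed independently of the tiled op handle
transform.loop.hoist_loop_invariant_subsets %k_loop : !transform.any_op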

New Operations

The new transform ops are listed below; a combined usage sketch follows the list.

  • transform.xegpu.set_operand_layout: Given a handle to an anchor op, such as xegpu.dpas, sets xegpu.layout attributes on its operands. Currently only DPAS ops are supported; the DPAS op must already have been tiled to the workgroup (WG) size and the reduction-loop (K) tile size. This op sets the sg_layout, sg_data and inst_data layout attributes.
  • transform.xegpu.insert_prefetch: Inserts prefetch operations for the operands of an xegpu op. Currently only supports the DPAS op. Sets the sg_layout and sg_data attributes, emits prefetch ops, and inserts them in the reduction loop.
  • transform.xegpu.hoist_desc_ops: Hoists xegpu.create_nd_tdesc ops out of the loop.
  • transform.xegpu.set_gpu_launch_threads: Given a handle to a gpu.launch op, sets the number of gpu threads. This op is a workaround to ensure the correct number of threads in the launch op.
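
For orientation, a combined usage sketch of the four ops is given below; the handle names %dpas_op, %loop_k and %launch_op are assumed to come from earlier transform.structured.match ops, and the full walkthrough follows in the next sections.

// assign a layout to a DPAS operand (index 0 = A, 1 = B, 2 = C)
transform.xegpu.set_operand_layout %dpas_op index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
// insert cooperative prefetches for an operand; returns new DPAS and loop handles
%dpas_op2, %loop_k2 = transform.xegpu.insert_prefetch %dpas_op %loop_k index = 0 sg_layout = [32, 1] sg_data = [8, 32] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
// hoist descriptor creation out of the reduction loop; returns a new loop handle
%loop_k3 = transform.xegpu.hoist_desc_ops %loop_k2 : (!transform.any_op) -> !transform.any_op
// fix the gpu.launch thread count
transform.xegpu.set_gpu_launch_threads %launch_op threads = [8, 4, 1] : !transform.any_op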

Example: 4k matrix multiplication payload

Consider the following 4k linalg.matmul payload function defined with tensors.

module {
  func.func @run(%arg0: tensor<4096x4096xf16>, %arg1: tensor<4096x4096xf16>,
                    %arg2: tensor<4096x4096xf16>) -> tensor<4096x4096xf16> {
    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<4096x4096xf16>, tensor<4096x4096xf16>)
                  outs(%arg2 : tensor<4096x4096xf16>) -> tensor<4096x4096xf16>
    return %0 : tensor<4096x4096xf16>
  }
}

Applying existing transforms

We can apply workgroup (WG) and reduction dimension (K) tiling using the following upstream transform operations on the matched linalg.matmul op handle:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // WG tiling
    %wg_matmul, %loop_wg = transform.structured.tile_using_forall %0 tile_sizes [256, 256] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // tile k-dimension
    %wgk_matmul, %loop_k = transform.structured.tile_using_for %wg_matmul tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    transform.yield
  }
}

This produces an scf.forall loop for the WG tiling, followed by an scf.for reduction loop. The matmul op has shape (256x32, 32x256) -> 256x256.

func.func @run(%arg0: tensor<4096x4096xf16>, %arg1: tensor<4096x4096xf16>, %arg2: tensor<4096x4096xf16>) -> tensor<4096x4096xf16> {
  %c32 = arith.constant 32 : index
  %c4096 = arith.constant 4096 : index
  %c0 = arith.constant 0 : index
  %c256 = arith.constant 256 : index
  %0 = scf.forall (%arg3, %arg4) in (16, 16) shared_outs(%arg5 = %arg2) -> (tensor<4096x4096xf16>) {
    %1 = arith.muli %arg3, %c256 overflow<nsw> : index
    %2 = arith.muli %arg4, %c256 overflow<nsw> : index
    %extracted_slice = tensor.extract_slice %arg0[%1, 0] [256, 4096] [1, 1] : tensor<4096x4096xf16> to tensor<256x4096xf16>
    %extracted_slice_0 = tensor.extract_slice %arg1[0, %2] [4096, 256] [1, 1] : tensor<4096x4096xf16> to tensor<4096x256xf16>
    %extracted_slice_1 = tensor.extract_slice %arg5[%1, %2] [256, 256] [1, 1] : tensor<4096x4096xf16> to tensor<256x256xf16>
    %3 = scf.for %arg6 = %c0 to %c4096 step %c32 iter_args(%arg7 = %extracted_slice_1) -> (tensor<256x256xf16>) {
      %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, %arg6] [256, 32] [1, 1] : tensor<256x4096xf16> to tensor<256x32xf16>
      %extracted_slice_3 = tensor.extract_slice %extracted_slice_0[%arg6, 0] [32, 256] [1, 1] : tensor<4096x256xf16> to tensor<32x256xf16>
      %4 = linalg.matmul ins(%extracted_slice_2, %extracted_slice_3 : tensor<256x32xf16>, tensor<32x256xf16>) outs(%arg7 : tensor<256x256xf16>) -> tensor<256x256xf16>
      scf.yield %4 : tensor<256x256xf16>
    }
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %arg5[%1, %2] [256, 256] [1, 1] : tensor<256x256xf16> into tensor<4096x4096xf16>
    }
  }
  return %0 : tensor<4096x4096xf16>
}

We can now vectorize the linalg.matmul op and hoist the loop-invariant C tile read/store ops. Hoisting can be safely applied as we are working on tensors, thus avoiding any memory side-effects.

    // vectorize, applies to all ops in the func
    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %func2 = transform.structured.vectorize_children_and_apply_patterns %func
      : (!transform.any_op) -> !transform.any_op

    // hoist loop invariant vector read/store ops
    %loop_k2 = transform.structured.match ops{["scf.for"]} in %func2 : (!transform.any_op) -> !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %loop_k2 : !transform.any_op

Next we bufferize the payload function and drop the redundant function return value.

    // bufferize
    %payload_mod = transform.get_parent_op %func2 {op_name = "builtin.module"} : (!transform.any_op) -> !transform.any_op
    %payload_mod2 = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %payload_mod  {bufferize_function_boundaries=true, allow_return_allocs_from_loops=true} : (!transform.any_op) -> !transform.any_op
    %payload_mod3 = transform.apply_registered_pass "drop-equivalent-buffer-results" to %payload_mod2 : (!transform.any_op) -> !transform.any_op

The matrix multiplication is now defined with the vector ops and memrefs.

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @run(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
  %c256 = arith.constant 256 : index
  %cst = arith.constant 0.000000e+00 : f16
  %c32 = arith.constant 32 : index
  %c4096 = arith.constant 4096 : index
  %c0 = arith.constant 0 : index
  scf.forall (%arg3, %arg4) in (16, 16) {
    %0 = arith.muli %arg3, %c256 overflow<nsw> : index
    %1 = arith.muli %arg4, %c256 overflow<nsw> : index
    %subview = memref.subview %arg2[%0, %1] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
    %2 = vector.transfer_read %subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<256x256xf16, strided<[4096, 1], offset: ?>>, vector<256x256xf16>
    %3 = scf.for %arg5 = %c0 to %c4096 step %c32 iter_args(%arg6 = %2) -> (vector<256x256xf16>) {
      %4 = vector.transfer_read %arg0[%0, %arg5], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<256x32xf16>
      %5 = vector.transfer_read %arg1[%arg5, %1], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<32x256xf16>
      %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %4, %5, %arg6 : vector<256x32xf16>, vector<32x256xf16> into vector<256x256xf16>
      scf.yield %6 : vector<256x256xf16>
    }
    vector.transfer_write %3, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<256x256xf16>, memref<256x256xf16, strided<[4096, 1], offset: ?>>
  }
  return
}

We can now apply existing gpu dialect passes to map this loop nest to gpu blocks and threads (WG and SG). We first convert the scf.forall loop to scf.parallel. The gpu-map-parallel-loops pass expects two scf.parallel loops, one for the WG and one for the SG level. At this stage, however, we only have the WG loop, so the pass assumes a single GPU thread. We will fix this later.

    // convert scf.forall to scf.parallel
    %loop_wg2 = transform.structured.match ops{["scf.forall"]} in %payload_mod3 : (!transform.any_op) -> !transform.any_op
    %loop_wg3 = transform.loop.forall_to_parallel %loop_wg2 : (!transform.any_op) -> !transform.any_op

    // apply passes
    %func3 = transform.structured.match ops{["func.func"]} in %payload_mod3 : (!transform.any_op) -> !transform.any_op
    %func4 = transform.apply_registered_pass "gpu-map-parallel-loops" to %func3 : (!transform.any_op) -> !transform.any_op
    %func5 = transform.apply_registered_pass "convert-parallel-loops-to-gpu" to %func4 : (!transform.any_op) -> !transform.any_op
    %func6 = transform.apply_registered_pass "lower-affine" to %func5 : (!transform.any_op) -> !transform.any_op
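
After these passes the WG loop has become a gpu.launch op, roughly of the following form (a sketch; value names and the body are elided): a 16x16 block grid with, for now, a single thread per block, which is why the thread count is fixed later with transform.xegpu.set_gpu_launch_threads.

gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %c16, %sz_by = %c16, %sz_bz = %c1)
           threads(%tx, %ty, %tz) in (%sz_tx = %c1, %sz_ty = %c1, %sz_tz = %c1) {
  // C tile read, reduction loop, C tile write
  gpu.terminator
}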

We can now apply the convert-vector-to-xegpu pass to convert the vector dialect ops to xegpu ops and fold memref.subview ops into the xegpu descriptor op.

    %func7 = transform.apply_registered_pass "convert-vector-to-xegpu" to %func6 : (!transform.any_op) -> !transform.any_op
    // fold duplicated xegpu desc ops
    transform.apply_cse to %func7 : !transform.any_op
    %func8 = transform.apply_registered_pass "xegpu-fold-alias-ops" to %func7 : (!transform.any_op) -> !transform.any_op

The reduction loop now reads:

    ...
    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = false>>
    %3 = xegpu.load_nd %2  : !xegpu.tensor_desc<256x256xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = false>> -> vector<256x256xf16>
    %4 = scf.for %arg15 = %c0 to %c4096 step %c32 iter_args(%arg16 = %3) -> (vector<256x256xf16>) {
      %5 = xegpu.create_nd_tdesc %arg0[%0, %arg15] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = false>>
      %6 = xegpu.load_nd %5  : !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = false>> -> vector<256x32xf16>
      %7 = xegpu.create_nd_tdesc %arg1[%arg15, %1] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = false>>
      %8 = xegpu.load_nd %7  : !xegpu.tensor_desc<32x256xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = false>> -> vector<32x256xf16>
      %9 = xegpu.dpas %6, %8, %arg16 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
      scf.yield %9 : vector<256x256xf16>
    }
    xegpu.store_nd %4, %2  : vector<256x256xf16>, !xegpu.tensor_desc<256x256xf16, #xegpu.block_tdesc_attr<memory_space =  global, array_length = 1 : i64, boundary_check = false>>
    ...

Applying xegpu transform ops

The above xegpu IR must be further optimized to get good performance. This is where the new xegpu transform ops come into play.

The transform.xegpu.set_operand_layout operation

The DPAS op is defined at the WG level without any indication of how it should be distributed across subgroups. To this end, we apply the transform.xegpu.set_operand_layout op, which sets the xegpu.layout attributes. We first match the DPAS op, and then apply the desired sg_layout, sg_data, and inst_data attributes for the A tile (operand index = 0):

%loop_k3 = transform.structured.match ops{["scf.for"]} in %func8 : (!transform.any_op) -> !transform.any_op
%dpas_op = transform.structured.match ops{["xegpu.dpas"]} in %loop_k3 : (!transform.any_op) -> !transform.any_op
transform.xegpu.set_operand_layout %dpas_op index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op

The B and C tiles are handled analogously:

transform.xegpu.set_operand_layout %dpas_op index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op
transform.xegpu.set_operand_layout %dpas_op index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op

Setting the layout for the C tile also sets the layout_result_0 attribute on the xegpu.dpas op. The final reduction loop with layout attributes is:

  %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>>
  %3 = xegpu.load_nd %2  : !xegpu.tensor_desc<256x256xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>> -> vector<256x256xf16>
  %4 = scf.for %arg15 = %c0 to %c4096 step %c32 iter_args(%arg16 = %3) -> (vector<256x256xf16>) {
    %5 = xegpu.create_nd_tdesc %arg0[%0, %arg15] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
    %6 = xegpu.load_nd %5  : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
    %7 = xegpu.convert_layout %6 <{input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>, target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}> : vector<256x32xf16>
    %8 = xegpu.create_nd_tdesc %arg1[%arg15, %1] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [32, 16]>>
    %9 = xegpu.load_nd %8  : !xegpu.tensor_desc<32x256xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [32, 16]>> -> vector<32x256xf16>
    %10 = xegpu.convert_layout %9 <{input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [32, 16]>, target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [16, 16]>}> : vector<32x256xf16>
    %11 = xegpu.dpas %7, %10, %arg16 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>} : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
    scf.yield %11 : vector<256x256xf16>
  }
  xegpu.store_nd %4, %2  : vector<256x256xf16>, !xegpu.tensor_desc<256x256xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>>

The transform.xegpu.hoist_desc_ops operation

The above IR still has the A and B descriptor ops within the reduction loop. These can be hoisted with the transform.xegpu.hoist_desc_ops op:

%loop_k4 = transform.xegpu.hoist_desc_ops %loop_k3 : (!transform.any_op) -> !transform.any_op

The descriptor ops are moved out of the loop; the descriptors are added to the loop's iter_args and offset update ops are inserted in the loop body.

...
%4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
%5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [32, 16]>>
%6:3 = scf.for %arg15 = %c0 to %c4096 step %c32 iter_args(..., %arg17 = %4, %arg18 = %5) -> (..., !xegpu.tensor_desc<256x32xf16, ...>, !xegpu.tensor_desc<32x256xf16, ...>) {
  %7 = xegpu.update_nd_offset %arg17, [0, %c32] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
  %8 = xegpu.load_nd %arg17  : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
  ...
  %10 = xegpu.update_nd_offset %arg18, [%c32, 0] : !xegpu.tensor_desc<32x256xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [32, 16]>>
  %11 = xegpu.load_nd %arg18  : !xegpu.tensor_desc<32x256xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [32, 16]>> -> vector<32x256xf16>
  ...
  scf.yield ..., %7, %10 : ..., !xegpu.tensor_desc<256x32xf16, ...>, !xegpu.tensor_desc<32x256xf16, ...>
}

This op replaces the scf.for op, so the original loop handle is invalidated and a handle to the new loop is returned.

The resulting IR can now be lowered further using the xegpu-wg-to-sg-distribute and xegpu-blocking passes.
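
These passes are not part of the schedule shown here, but as a sketch they could be applied from the transform sequence in the same way as the other registered passes above (handle names are illustrative):

%func_sg = transform.apply_registered_pass "xegpu-wg-to-sg-distribute" to %func8 : (!transform.any_op) -> !transform.any_op
%func_blk = transform.apply_registered_pass "xegpu-blocking" to %func_sg : (!transform.any_op) -> !transform.any_op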

The transform.xegpu.insert_prefetch operation

Cooperative prefetching can be added using the transform.xegpu.insert_prefetch op. The op takes handles to the DPAS op whose operands we want to prefetch and to the reduction loop. For the A tile, we prefetch the 256x32 tile using 32 threads along the first dimension, i.e. each thread fetches an 8x32 tile:

%dpas_op2, %loop_k4 = transform.xegpu.insert_prefetch %dpas_op %loop_k3 index = 0 sg_layout = [32, 1] sg_data = [8, 32] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

This emits the descriptor, offset-update, and prefetch ops before and inside the reduction loop:

...
%4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32]>>
%5 = xegpu.update_nd_offset %4, [0, %c32] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32]>>
xegpu.prefetch_nd %4 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32]>>
%6:2 = scf.for %arg15 = %c0 to %c4096 step %c32 iter_args(%arg16 = %5, ...) -> (!xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32]>>, ...) {
  %7 = xegpu.update_nd_offset %arg16, [0, %c32] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32]>>
  xegpu.prefetch_nd %arg16 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<cached>, l3_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32]>>
  ...
  scf.yield %7, ... : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 32]>>, ...
}

The B tile prefetches are handled analogously. Here we choose to prefetch the 32x256 tile using 32 threads in a [4, 8] layout, each thread again fetching an 8x32 tile:

%dpas_op3, %loop_k5 = transform.xegpu.insert_prefetch %dpas_op2 %loop_k4 index = 1 sg_layout = [4, 8] sg_data = [8, 32] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

The transform.xegpu.set_gpu_launch_threads operation

Finally, we fix the number of threads in the gpu.launch op with the following op; the thread count [8, 4, 1] matches the sg_layout = [8, 4] used above:

%launch_op = transform.structured.match ops{["gpu.launch"]} in %func8 : (!transform.any_op) -> !transform.any_op
transform.xegpu.set_gpu_launch_threads %launch_op threads = [8, 4, 1] : !transform.any_op

Full lowering schedule

Combining the above transformations we can now write the full lowering schedule for the matmul operation:

module {
  func.func @run(%arg0: tensor<4096x4096xf16>, %arg1: tensor<4096x4096xf16>,
                    %arg2: tensor<4096x4096xf16>) -> tensor<4096x4096xf16> {
    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<4096x4096xf16>, tensor<4096x4096xf16>)
                  outs(%arg2 : tensor<4096x4096xf16>) -> tensor<4096x4096xf16>
    return %0 : tensor<4096x4096xf16>
  }
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op

    // WG tiling
    %wg_matmul, %loop_wg = transform.structured.tile_using_forall %0 tile_sizes [256, 256] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    // tile k-dimension
    %wgk_matmul, %loop_k = transform.structured.tile_using_for %wg_matmul tile_sizes [0, 0, 32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // vectorize
    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %func2 = transform.structured.vectorize_children_and_apply_patterns %func
      : (!transform.any_op) -> !transform.any_op

    // hoist loop invariant vector read/store ops
    %loop_k2 = transform.structured.match ops{["scf.for"]} in %func2 : (!transform.any_op) -> !transform.any_op
    transform.loop.hoist_loop_invariant_subsets %loop_k2 : !transform.any_op

    // bufferize
    %payload_mod = transform.get_parent_op %func2 {op_name = "builtin.module"} : (!transform.any_op) -> !transform.any_op
    %payload_mod2 = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap} %payload_mod  {bufferize_function_boundaries=true, allow_return_allocs_from_loops=true} : (!transform.any_op) -> !transform.any_op
    %payload_mod3 = transform.apply_registered_pass "drop-equivalent-buffer-results" to %payload_mod2 : (!transform.any_op) -> !transform.any_op

    // convert scf.forall to scf.parallel
    %loop_wg2 = transform.structured.match ops{["scf.forall"]} in %payload_mod3 : (!transform.any_op) -> !transform.any_op
    %loop_wg3 = transform.loop.forall_to_parallel %loop_wg2 : (!transform.any_op) -> !transform.any_op

    // apply passes
    %func3 = transform.structured.match ops{["func.func"]} in %payload_mod3 : (!transform.any_op) -> !transform.any_op
    %func4 = transform.apply_registered_pass "gpu-map-parallel-loops" to %func3 : (!transform.any_op) -> !transform.any_op
    %func5 = transform.apply_registered_pass "convert-parallel-loops-to-gpu" to %func4 : (!transform.any_op) -> !transform.any_op
    %func6 = transform.apply_registered_pass "lower-affine" to %func5 : (!transform.any_op) -> !transform.any_op

    // canonicalize
    transform.apply_cse to %func6 : !transform.any_op
    transform.apply_patterns to %func6 {
      transform.apply_patterns.canonicalization
    } : !transform.any_op

    %func7 = transform.apply_registered_pass "convert-vector-to-xegpu" to %func6 : (!transform.any_op) -> !transform.any_op
    // fold duplicated xegpu desc ops
    transform.apply_cse to %func7 : !transform.any_op
    %func8 = transform.apply_registered_pass "xegpu-fold-alias-ops" to %func7 : (!transform.any_op) -> !transform.any_op

    %loop_k3 = transform.structured.match ops{["scf.for"]} in %func8 : (!transform.any_op) -> !transform.any_op

    // Add layouts to DPAS op operands
    %dpas_op = transform.structured.match ops{["xegpu.dpas"]} in %loop_k3 : (!transform.any_op) -> !transform.any_op

    transform.xegpu.set_operand_layout %dpas_op index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
    transform.xegpu.set_operand_layout %dpas_op index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [16, 16] : !transform.any_op
    transform.xegpu.set_operand_layout %dpas_op index = 2 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op

    // Insert prefetch ops for DPAS A and B tiles. Should be applied before hoisting desc ops.
    %dpas_op2, %loop_k4 = transform.xegpu.insert_prefetch %dpas_op %loop_k3 index = 0 sg_layout = [32, 1] sg_data = [8, 32] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
    %dpas_op3, %loop_k5 = transform.xegpu.insert_prefetch %dpas_op2 %loop_k4 index = 1 sg_layout = [4, 8] sg_data = [8, 32] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Hoist desc ops out of reduction loop.
    %loop_k6 = transform.xegpu.hoist_desc_ops %loop_k5 : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %func8 {
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.apply_cse to %func8 : !transform.any_op

    // Set correct number of gpu threads.
    // NOTE should use gpu transform ops to do this, instead of convert-parallel-loops-to-gpu pass
    %launch_op = transform.structured.match ops{["gpu.launch"]} in %func8 : (!transform.any_op) -> !transform.any_op
    transform.xegpu.set_gpu_launch_threads %launch_op threads = [8, 4, 1] : !transform.any_op

    // Outline gpu func.
    %func9 = transform.apply_registered_pass "gpu-launch-sink-index-computations" to %func8 : (!transform.any_op) -> !transform.any_op
    %payload_mod4 = transform.apply_registered_pass "gpu-kernel-outlining" to %payload_mod3 : (!transform.any_op) -> !transform.any_op

    transform.apply_patterns to %payload_mod4 {
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.apply_cse to %payload_mod4 : !transform.any_op

    transform.yield
  }
}

The above schedule exposes the following parameters:

  • WG tile size: [256, 256]
  • SG tile size: [32, 64]
  • K tile size: 32
  • DPAS tile sizes: [8, 16, 16]
  • Prefetch tile sizes for A and B: [8, 32], [8, 32]

The output IR after the above schedule has been applied can be found here (now outdated).

Performance

The above schedule yields ~200 TFLOPS on a single PVC tile and passes the correctness test.

Discussion / Future work

  • Generalize the xegpu.set_operand_layout and xegpu.insert_prefetch ops to support ops other than xegpu.dpas.
  • Add support for layout conversions, for example a different inst_data tile between a load and its use. In the long term, we could have xegpu.set_operand_layout and xegpu.set_result_layout ops that set attrs for individual ops and use the XeGPU layout propagation mechanism (under development) to handle layout conversions.
  • The xegpu.set_gpu_launch_threads op should be handled differently in the future, preferably using suitable gpu dialect transform ops. It is included for the time being so that the IR can be executed correctly.
  • The above transform schedule, or parts of it, could be exposed as a pass if needed. The full pass could have a signature like linalg-matmul-to-xegpu{wg-tile=256,256 sg-tile=32,64 k-tile=32 dpas-tile=8,16,16 a-prefetch=8,32 b-prefetch=8,32 a-load=32,16 b-load=32,16}; see the sketch after this list. Such a pass would apply the same transforms to all DPAS ops.
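
As a sketch, such a pass (hypothetical; neither the pass nor its options exist yet, and anchoring it at the module level is an assumption) could then be invoked as:

mlir-opt matmul.mlir --pass-pipeline="builtin.module(linalg-matmul-to-xegpu{wg-tile=256,256 sg-tile=32,64 k-tile=32 dpas-tile=8,16,16 a-prefetch=8,32 b-prefetch=8,32 a-load=32,16 b-load=32,16})"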

tkarna force-pushed the tkarna/xegpu-transform-ops branch from 1c0906a to df1b9a3 on July 30, 2025.
chencha3 (Collaborator) left a comment:

I had an initial pass. I will go through it again later. It currently looks to me that we need to generalize the layout setting. Currently, the implementation is limited to support a few specific cases only. Can we run analysis inside __transform_main, and query the analysis result for each Op or Value?

}

auto sgLayout = getSgLayout();
if (sgLayout.size() != 2) {
Collaborator commented (inline, on the snippet above):

What is the reason to limit the rank to be 2?

tkarna (Owner, Author) replied on Aug 5, 2025:

So far I have only considered 2D inputs, namely 2D matmul op. This can be generalized as needed. My goal here was to demonstrate the transform ops with a 2D matmul and use that as the first CI test. Generalization can be added in the same PR or in a follow up with more tests.

tkarna (Owner, Author) commented on Aug 5, 2025:

Can we run analysis inside __transform_main, and query the analysis result for each Op or Value?

Why would you need such analysis? Normally, I think, it is sufficient to inspect the payload op handle and transform op arguments for, say, verification purposes.

tkarna requested a review from Jianhui-Li on August 6, 2025.
chencha3 (Collaborator) commented on Aug 6, 2025:

Can we run analysis inside __transform_main, and query the analysis result for each Op or Value?

Why would you need such analysis? Normally, I think, it is sufficient to inspect the payload op handle and transform op arguments for, say, verification purposes.

I mean, from the transform perspective: how do we systematically assign layouts to each OpResult and OpOperand in a kernel?

tkarna force-pushed the tkarna/xegpu-transform-ops branch from df1b9a3 to 8b11bfd on August 8, 2025.
tkarna (Owner, Author) commented on Aug 8, 2025:

Updates:

  • Renamed xegpu.set_dpas_layout to more generic set_operand_layout.
  • Removed load_data argument from set_operand_layout op. We will address layout conversions in a follow-up PR.
  • Removed all references to DPAS op from the code where possible.
  • Removed the XeGPU prefix from op names, e.g. XeGPUInsertPrefetchOp -> InsertPrefetchOp. This is useful for the Python bindings, where the syntax is xegpu.InsertPrefetchOp(...).
  • Added Python bindings for the xegpu transform ops, and tests.


let summary = "Hoists xegpu tile descriptor ops outside the containing loop";
let description = [{
Hoists `xegpu.create_nd_tdesc` out of the loop. If the
Collaborator commented (inline, on the snippet above):

This pass may become unnecessary as we are transitioning to a new create_nd_tdesc definition: the nd_tdesc is created without an offset and the offset moves to load_nd. create_nd_tdesc would then become loop-invariant.
Referring to these PRs:
a.1. make offset option for create_nd_tdesc (llvm#148335)

a.2. add optional offsets for load_nd and store_nd/prefetch_nd. (llvm#149424)

You may look at Imex innersource github issue#1151 for more background info.

tkarna (Owner, Author) replied:

Thanks, yes, I'm aware of this planned change. It implies some changes to the transform ops - it should in fact make the logic simpler in most cases. Hoisting the desc ops is still needed, but indeed we might be able to use existing hoist patterns instead of an xegpu-specific method. We can address this issue once the new load_nd-offset pipeline is complete. In the meantime, for my part, we could upstream these transform ops so that we can support linalg.matmul lowering.


let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
let description = [{
Given an xegpu operation residing in a `scf.for` loop, this transform inserts cooperative `xegpu.prefetch` operations for the A (index = 0) or B (index = 1) operand. The prefetch tile size is determined by the `sg_layout` and `sg_data` attributes.
Collaborator commented (inline, on the snippet above):

Do you mean the input is a xegpu DPAS op?

tkarna (Owner, Author) replied:

Yes, the implementation only supports DPAS op at the moment.

auto layoutAttr =
createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
descOp = setDescLayout(rewriter, descOp, layoutAttr);
if (operandIndex == 2) {
Collaborator commented (inline, on the snippet above):

Does the current implementation still assume the operation is a dpasOp? If so, maybe you can add a TODO note.

tkarna (Owner, Author) commented on Aug 13, 2025:

Upstreaming these ops is deferred due to the ongoing changes in the xegpu dialect. Closing.

tkarna closed this on Aug 13, 2025.