@dhernandez0
Contributor

Motivation

Add prefetch functionality for gfx1250.

Technical Details

  • Added rock.threadwise_prefetch (similar to threadwise_read_into).
  • Added rock.global_prefetch (similar to global_load). It lowers to memref.prefetch, which in turn emits global_prefetch_b8 on gfx1250.
  • Added a rock.threadwise_prefetch call in load_tile: just after loading the current tile, we start prefetching the next one.

Note that there is no need to call prefetch conditionally, because memref.prefetch is lowered to a no-op on targets that do not support hardware prefetching.
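
As a rough CPU-side analogy for the load_tile change (not the Rock lowering itself), the sketch below issues a prefetch hint for the next tile before working on the current one. It assumes a GCC or Clang compiler, since `__builtin_prefetch` is a compiler builtin; `sumTiled` and its parameters are hypothetical names used only for illustration. The hint is unconditional, mirroring the note above; the only clamping is there to keep the C++ pointer arithmetic valid.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>

// Sums `n` floats tile by tile. Before consuming the current tile, a
// prefetch hint is issued for the start of the next tile so that the
// memory request can overlap with the work on the current tile.
float sumTiled(const float* data, std::size_t n, std::size_t tile) {
  assert(tile > 0 && "tile size must be positive");
  float total = 0.0f;
  for (std::size_t start = 0; start < n; start += tile) {
    const std::size_t end = std::min(start + tile, n);
    // Hint only, never dereferenced. `end` is clamped to one-past-the-end
    // so that forming the pointer is valid C++ (assumption of this sketch).
    __builtin_prefetch(data + end, /*rw=*/0, /*locality=*/3);
    for (std::size_t i = start; i < end; ++i)
      total += data[i];
  }
  return total;
}
```

On compilers or targets without a usable prefetch instruction the builtin typically compiles to nothing, which is the same property the PR relies on for memref.prefetch.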

Some work remains pending until we can measure performance:

  • Measure whether prefetching improves performance.
  • Evaluate what happens if the prefetch is issued later.

Compilation for gfx1250 currently fails, so one test carries a TODO until that is fixed.

Test Plan

Added tests for rock.global_prefetch and rock.threadwise_prefetch.

Test Result

Tests pass.

Submission Checklist

Contributor

Copilot AI left a comment

Pull Request Overview

This PR introduces support for prefetching operations in the Rock dialect to improve memory access performance. The main additions include new global_prefetch and threadwise_prefetch operations with corresponding lowering passes and test coverage.

  • Adds GlobalPrefetchOp and ThreadwisePrefetchOp operations to the Rock dialect
  • Implements lowering patterns for prefetch operations through the compilation pipeline
  • Adds support for gfx1250 architecture in the AMD architecture database
  • Includes comprehensive test coverage for prefetch assembly generation and operation lowering

Reviewed Changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Defines new global_prefetch and threadwise_prefetch operations |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Implements verification logic for new prefetch operations |
| mlir/lib/Dialect/Rock/Transforms/SugarToLoops.cpp | Adds lowering pattern for GlobalPrefetchOp to memref.prefetch |
| mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp | Implements lowering pattern for ThreadwisePrefetchOp |
| mlir/lib/Dialect/Rock/Transforms/BlockwiseLoadTileToThreadwise.cpp | Integrates prefetching into blockwise load operations |
| mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp | Adds handling for ThreadwisePrefetchOp in tracing logic |
| mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp | Adds gfx1250 architecture information |
| mlir/test/rocmlir-driver/prefetch_assembly.mlir | Tests prefetch instruction generation for different GPU architectures |
| mlir/test/Dialect/Rock/lowering_threadwise_prefetch.mlir | Tests threadwise prefetch lowering |
| mlir/test/Dialect/Rock/lowering_global_prefetch.mlir | Tests global prefetch lowering |
| mlir/test/Dialect/Rock/gridwise_attention_accel_lowering.mlir | Updates existing test to account for new prefetch operations |
| mlir/test/Dialect/Rock/lowering_global_load_store.mlir | Fixes function name inconsistency |

Comments suppressed due to low confidence (1)

mlir/test/Dialect/Rock/lowering_global_load_store.mlir:1

  • Removed trailing whitespace at the end of line 24.
// RUN: rocmlir-opt --rock-sugar-to-loops %s | FileCheck %s

where MAPS is `[#transform_mapM, #transform_mapN]`,
L is the length of the upper view last dimension (number of elements).

The input to extraViews ; (the transforms on %source) must have the form

Copilot AI Oct 30, 2025

Incorrect punctuation: semicolon should be removed or replaced with a comma or colon.

Suggested change
- The input to extraViews ; (the transforms on %source) must have the form
+ The input to extraViews: (the transforms on %source) must have the form

Contributor

@pabloantoniom pabloantoniom left a comment

General question: AFAIK, prefetch was available before gfx1250 (prefetching to L2), but this PR introduces rock ops for prefetching, so I assume we didn't have prefetch support before? If that's the case, maybe we can open a ticket to investigate prefetch on other architectures as well.

// CHECK-SAME: (%[[mem:.*]]: memref<f32>, %[[idx:.*]]: index)
func.func @load_scalar_empty_mem(%mem: memref<f32>, %idx: index) -> f32 {
// CHECK-SAME: (%[[mem:.*]]: memref<f32>)
func.func @load_scalar(%mem: memref<f32>) -> f32 {

Contributor

This change does not seem related to this PR

Contributor Author

It's not, but it's a small fix. Do we want an independent PR for this?

Contributor

Just wanted to double-check that you meant to commit that. I'm not a big fan of unrelated changes, but I'm OK with it.

def Rock_GlobalPrefetchOp
: Rock_Op<"global_prefetch">,
Arguments<(
ins Arg<MemRefOf<SupportedMemoryElems>, "source memory">:$source,

Contributor

Do we want to limit the scope of $source to be #gpu.address_space<global>?

Contributor Author

I think we need to do that in the verifier if we want to

: Rock_Op<"global_prefetch">,
Arguments<(
ins Arg<MemRefOf<SupportedMemoryElems>, "source memory">:$source,
Variadic<Index>:$sourceCoord)> {

Contributor

Why sourceCoord and not just offset?

Contributor Author

to keep it consistent with global_load

Contributor

Yeah I can see we have a lot of sourceCoord in the codebase, not a fan, but fair to leave it like this for global_prefetch


Note that it's ok if `sourceCoord` are out of bounds because we use Speculative Prefetch.
}];
let assemblyFormat = [{

Contributor

Happy to see this 😄


template <typename Load>
static LogicalResult verifyGlobalLoad(Load op) {
if (failed(verifyGlobalLoadAndPrefetch(op)))

Contributor

Not sure I understand: if we want to verify a global load (i.e., inside verifyGlobalLoad), why do we also need to verify it as if it were a prefetch (i.e., by calling verifyGlobalLoadAndPrefetch)?

Contributor Author

global_load and global_prefetch had common verification code, so I created a function that applies to both (verifyGlobalLoadAndPrefetch).

Contributor

Ah, I see the idea behind it now. At first it was confusing because I thought verifyGlobalLoadAndPrefetch was performing checks specific to prefetch ops, so it didn't make sense to call it from verifyGlobalLoad.

Contributor Author

Yes, it might be confusing. I'll try to think of a better name for the function, or just add a comment to clarify.
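
The shared-verifier pattern discussed in this thread can be illustrated with a small standalone sketch. Everything here is hypothetical: `AccessInfo`, `FakeGlobalLoad`, `FakeGlobalPrefetch`, and the specific checks are stand-ins chosen for illustration, not the actual Rock verifier, and `verifyGlobalAccessCommon` is just one possible name that avoids suggesting prefetch-specific checks.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the operands that both ops share.
struct AccessInfo {
  bool sourceIsGlobal;       // source memref lives in global memory
  std::size_t sourceRank;    // rank of the source memref
  std::vector<long> coords;  // one coordinate per source dimension
};

// Hypothetical stand-ins for the two ops.
struct FakeGlobalLoad {
  AccessInfo access;
  bool resultTypeMatches;    // load-only check: result type vs. source
};
struct FakeGlobalPrefetch {
  AccessInfo access;
};

// Checks that apply to any op addressing global memory by coordinates.
template <typename Op>
bool verifyGlobalAccessCommon(const Op& op) {
  return op.access.sourceIsGlobal &&
         op.access.coords.size() == op.access.sourceRank;
}

// The load verifier runs the shared checks, then its own.
bool verifyGlobalLoad(const FakeGlobalLoad& op) {
  return verifyGlobalAccessCommon(op) && op.resultTypeMatches;
}

// The prefetch verifier has no result, so only the shared checks apply.
bool verifyGlobalPrefetch(const FakeGlobalPrefetch& op) {
  return verifyGlobalAccessCommon(op);
}
```

Naming the helper after what it checks rather than after its callers keeps the call from verifyGlobalLoad from reading as a prefetch check.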

def Rock_ThreadwisePrefetchOp
: Rock_Op<"threadwise_prefetch", [DeclareOpInterfaceMethods<
RockAcceptingViewOpInterface>]>,
Arguments<(ins Arg<MemRefOf<SupportedMemoryElems>, "source view">:$source,

Contributor

Here we accept MemRefOf but verifier fails if gpuSrcMemSpaceAttr.getValue() != gpu::AddressSpace::Global. Don't we have a mechanism in tablegen to check for this?

Contributor Author

not that I know of, but happy to add it if you know how

Contributor

Right, I was going to suggest we can use something like:

def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;

but that is for LLVM pointers not for MLIR memrefs...there's nothing upstream for that yet.

arith::AddIOp::create(b, loc, indicesNext[0], one).getResult();

// it's ok if the indices are out of bounds because we use
// GLOBAL_PREFETCH_B8 with Speculative Prefetch

Contributor

Can you elaborate on this comment? How can we be sure that the prefetch will be speculative?

This will be lowered to rock.global_prefetch and then to memref.prefetch, but it's not obvious to me how memref.prefetch lowers to the GLOBAL_PREFETCH_B8, or how we control that it is a speculative prefetch (e.g., how the temporal hint (TH) field gets set).

Also, related to this, it could be useful to add to the description of the PR a link to the upstream changes that implemented the lowering to GLOBAL_PREFETCH_B8. Maybe all of this is explained there?

Contributor Author

This is described in the documentation, llvm/docs/AMDGPUUsage.rst:

  :ref:`llvm.prefetch <int_prefetch>`              Implemented on gfx1250, ignored on earlier targets.
                                                   First argument is flat, global, or constant address space pointer.
                                                   Any other address space is not supported.
                                                   On gfx125x generates flat_prefetch_b8 or global_prefetch_b8 and brings data to GL2.
                                                   Second argument is rw and currently ignored. Can be 0 or 1.
                                                   Third argument is locality, 0-3. Translates to memory scope:

                                                   * 0 - SCOPE_SYS
                                                   * 1 - SCOPE_DEV
                                                   * 2 - SCOPE_SE
                                                   * 3 - SCOPE_SE

                                                   Note that SCOPE_CU is not generated and not safe on an invalid address.
                                                   Fourth argument is cache type:

                                                   * 0 - Instruction cache, currently ignored and no code is generated.
                                                   * 1 - Data cache.

                                                   Instruction cache prefetches are unsafe on invalid address.
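
To make the locality-to-scope table quoted above easier to scan, here is a small standalone sketch that restates it as a lookup. `scopeForLocality` is a hypothetical helper written for this illustration; it is not an LLVM or MLIR API, and it only mirrors the four rows of the quoted documentation.

```cpp
#include <stdexcept>
#include <string>

// Restates the quoted table: llvm.prefetch locality (0-3) to the memory
// scope named in the documentation. Localities 2 and 3 both map to
// SCOPE_SE; SCOPE_CU is never produced.
std::string scopeForLocality(unsigned locality) {
  switch (locality) {
  case 0:
    return "SCOPE_SYS";
  case 1:
    return "SCOPE_DEV";
  case 2:
  case 3:
    return "SCOPE_SE";
  default:
    throw std::out_of_range("locality must be in [0, 3]");
  }
}
```

Under this table, a locality of 3 (the value passed when lowering to memref.prefetch in this PR) corresponds to SCOPE_SE, not SCOPE_CU, the scope the documentation flags as unsafe on an invalid address.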

Contributor Author

From this, I understand that as long as we use data prefetch, it's safe to use invalid addresses.

Contributor Author

> but it's not obvious to me how memref.prefetch lowers to the GLOBAL_PREFETCH_B8

We have a test that checks the generated assembly contains GLOBAL_PREFETCH_B8, but it's disabled until gfx1250 compilation is fixed.

> it could be useful to add to the description of the PR a link to the upstream changes that implemented the lowering to GLOBAL_PREFETCH_B8

Sure, I can do that.

Contributor Author

here's the PR: llvm/llvm-project#150493

Contributor

So, as far as I understand, ThreadwisePrefetchOp does not actually decide the locality; that is decided later, in memref::PrefetchOp. Would it make sense to have an attribute in ThreadwisePrefetchOp that specifies the locality and is passed through to PrefetchOp, instead of passing a 3 when generating the PrefetchOp?

Also, it's still unclear to me how this will generate a speculative prefetch. I mean, what field or attribute controls that? I think we should add a comment to make this more explicit.

Contributor

OK, I found the answer to the second question in the PDF. I think it would be good to have a comment mentioning this.


source = asGlobal(b, source);
b.replaceOpWithNewOp<memref::PrefetchOp>(
op, source, op.getSourceCoord(), /*isWrite=*/false, /*localityHint=*/3,

Contributor

Is localityHint controlling where we are prefetching to, e.g., WGP or L2?

Also, isn't there a key in memref that we can use instead of hardcoding a 3? Would be good to either add a comment or have some enum to make explicit what this 3 means...

Contributor Author

I can add a comment
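
One way to make the hardcoded 3 self-describing, as suggested in this thread, is a named constant. The sketch below is standalone and hypothetical: `PrefetchLocality` and `toLocalityHint` do not exist in MLIR or rocMLIR. It assumes only that memref.prefetch takes an integer locality hint from 0 (no locality) to 3 (highest locality), as described in the upstream op documentation.

```cpp
#include <cstdint>

// Hypothetical names for the integer locality hint accepted by
// memref.prefetch (0 = no temporal locality, 3 = keep as close as possible).
enum class PrefetchLocality : std::uint32_t {
  None = 0,
  Low = 1,
  Moderate = 2,
  High = 3,
};

// Converts the named value to the raw integer the op builder expects.
constexpr std::uint32_t toLocalityHint(PrefetchLocality locality) {
  return static_cast<std::uint32_t>(locality);
}

// The value currently hardcoded in the lowering.
static_assert(toLocalityHint(PrefetchLocality::High) == 3,
              "High must match the hardcoded locality hint");
```

A call site would then pass `toLocalityHint(PrefetchLocality::High)` instead of a bare 3, which documents the intent without changing the generated op.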

@dhernandez0
Contributor Author

> General question: AFAIK, prefetch was available before gfx1250 (prefetching to L2), but this PR introduces rock ops for prefetching, so I assume we didn't have prefetch support before? If that's the case, maybe we can open a ticket to investigate prefetch on other architectures as well.

I don't think previous architectures had any prefetch from global memory to L2.

@dhernandez0 force-pushed the 2067-use-prefetch-instructions branch from a0f5d92 to cb48e56 (October 31, 2025 10:38)
@pabloantoniom
Contributor

LGTM. I think we just need to complete the TODO in mlir/test/rocmlir-driver/prefetch_assembly.mlir.

@dhernandez0 force-pushed the 2067-use-prefetch-instructions branch from 7ee84c2 to be1d8cc (November 3, 2025 10:37)