Skip to content

Conversation

anamikac-intel
Copy link

@anamikac-intel anamikac-intel commented Sep 29, 2025

Modify 00_bmg_gemm to include new mma and copy atoms (#477).
00_bmg_gemm combines two parts: mma and epilogue. To add new atom changes, we need to update both parts since they currently use old atoms. As starting we will:

Keep CollectiveEpilogue unchanged for now
Only modify CollectiveMma first

@anamikac-intel anamikac-intel marked this pull request as ready for review September 29, 2025 08:11
@anamikac-intel anamikac-intel changed the title Use newer version on mma_atom and copy_atom in 00_bmg_gemm Use newer version of mma_atom and copy_atom in 00_bmg_gemm Sep 30, 2025
Comment on lines 179 to 192
auto copy_a = [&]() {
if constexpr (!std::is_void_v<GmemTiledCopyA>) {
// User provided copy operation - use full stride
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A),
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), mainloop.dA));
using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
return Copy_A{}.with(mA_mkl);
} else {
// Use new 2D copy operations with 2D stride
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A),
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), cute::take<0,2>(mainloop.dA)));
return make_block_2d_copy_A(TiledMma{}, mA_mkl);
}
}();
Copy link

@petercad petercad Oct 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few comments:

  • mA_mkl should be the same on both branches, representing the same global memory tensor. You can use the definition on line 188 unconditionally, and pull it out of the if constexpr.
  • take<B,E> means take modes B,B+1,...,E-1, so it should be take<0,3> here.
  • DefaultTiledCopy is a hack from the legacy copy atoms -- new copy atoms don't have this, and it won't work correctly even with legacy atoms because they were tiled at subgroup scope. Instead, you can pass the user's op to make_block_2d_copy_*.
Suggested change
auto copy_a = [&]() {
if constexpr (!std::is_void_v<GmemTiledCopyA>) {
// User provided copy operation - use full stride
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A),
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), mainloop.dA));
using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>;
return Copy_A{}.with(mA_mkl);
} else {
// Use new 2D copy operations with 2D stride
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A),
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), cute::take<0,2>(mainloop.dA)));
return make_block_2d_copy_A(TiledMma{}, mA_mkl);
}
}();
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A),
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), cute::take<0,3>(mainloop.dA)));
auto copy_a = [&] {
if constexpr (!std::is_void_v<GmemTiledCopyA>) {
// User provided copy operation - use full stride
return make_block_2d_copy_A(GmemTiledCopyA{}, TiledMma{}, mA_mkl);
} else {
return make_block_2d_copy_A(TiledMma{}, mA_mkl);
}
}();

Copy link
Author

@anamikac-intel anamikac-intel Oct 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@petercad Few Observation with change suggested :

  1. Stride Dimension Mismatch Issue:
    The stride values for matrices A and B are (4096, -1, 0) and (-1, 4096, 0) respectively.
    image

    So when attempting to use all dimensions , it fails in make_block_2d_copy_* at:
    **/include/cute/atom/copy_traits_xe_2d.hpp:503:18
    auto x_shape = elem_scale(ShapeTiler_MN{}, atom_shape);

    The failure occurs in the elem_scale function due to a tuple size mismatch:
    Tuple 1: cute::tuple<cute::C<32>, cute::C<32>, cute::C<1>> (3 elements)
    Tuple 2: cute::tuple<cute::C<8>, cute::C<1>> (2 elements)
    The transform function requires both tuples to have identical sizes, but the mismatch (3 ≠ 2) causes the compilation error.**
    image

  2. Missing Copy Atom Members:
    The make_block_2d_copy_A(GmemTiledCopyA{}, TiledMma{}, mA_mkl) approach is also failing because the old copy atom structures cute::XE_2D_U16x32x32_LD_N and cute::XE_2D_U16x32x32_LD_V are missing below three member variables:

       1. AtomWidth 
       2. AtomHeight 
       3. CopyBits 
    

    **The make_block_2d_copy_X function expects these members to be present in the copy operation classes, as shown in the compilation errors at:

    /include/cute/atom/copy_traits_xe_2d.hpp:740:33
    constexpr int Width = CopyOp::AtomWidth * CopyOp::CopyBits / sizeof_bits_v;
    /include/cute/atom/copy_traits_xe_2d.hpp:741:34
    constexpr int Height = CopyOp::AtomHeight;**

    image image

    Even after defining those member variables in old atom I am landing on missing type definitions and template parameter mismatch errors.. seems old atoms not compatible with make_block_2d_copy_*

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On 1 -- let me look into it, may be a bug in make_block_2d_copy_*

On 2 -- we don't need to support the old atoms (as discussed offline). Compile errors here are expected if you tried to use the old atoms (maybe we can add some code to make_block_2d_copy_* to catch usages of the old ops and provide a nicer compile-time error).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants