-
Notifications
You must be signed in to change notification settings - Fork 57
Use newer version of mma_atom and copy_atom in 00_bmg_gemm #540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
auto copy_a = [&]() { | ||
if constexpr (!std::is_void_v<GmemTiledCopyA>) { | ||
// User provided copy operation - use full stride | ||
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A), | ||
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), mainloop.dA)); | ||
using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>; | ||
return Copy_A{}.with(mA_mkl); | ||
} else { | ||
// Use new 2D copy operations with 2D stride | ||
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A), | ||
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), cute::take<0,2>(mainloop.dA))); | ||
return make_block_2d_copy_A(TiledMma{}, mA_mkl); | ||
} | ||
}(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few comments:
mA_mkl
should be the same on both branches, representing the same global memory tensor. You can use the definition on line 188 unconditionally, and pull it out of theif constexpr
.take<B,E>
means take modes B,B+1,...,E-1, so it should be take<0,3> here.DefaultTiledCopy
is a hack from the legacy copy atoms -- new copy atoms don't have this, and it won't work correctly even with legacy atoms because they were tiled at subgroup scope. Instead, you can pass the user's op tomake_block_2d_copy_*
.
auto copy_a = [&]() { | |
if constexpr (!std::is_void_v<GmemTiledCopyA>) { | |
// User provided copy operation - use full stride | |
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A), | |
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), mainloop.dA)); | |
using Copy_A = typename Copy_Traits<GmemTiledCopyA, StrideA>::template DefaultTiledCopy<ElementA>; | |
return Copy_A{}.with(mA_mkl); | |
} else { | |
// Use new 2D copy operations with 2D stride | |
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A), | |
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), cute::take<0,2>(mainloop.dA))); | |
return make_block_2d_copy_A(TiledMma{}, mA_mkl); | |
} | |
}(); | |
auto mA_mkl = make_tensor(make_gmem_ptr(mainloop.ptr_A), | |
make_layout(make_shape(mainloop.M, mainloop.K, mainloop.L), cute::take<0,3>(mainloop.dA))); | |
auto copy_a = [&] { | |
if constexpr (!std::is_void_v<GmemTiledCopyA>) { | |
// User provided copy operation - use full stride | |
return make_block_2d_copy_A(GmemTiledCopyA{}, TiledMma{}, mA_mkl); | |
} else { | |
return make_block_2d_copy_A(TiledMma{}, mA_mkl); | |
} | |
}(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@petercad Few Observation with change suggested :
-
Stride Dimension Mismatch Issue:
The stride values for matrices A and B are (4096, -1, 0) and (-1, 4096, 0) respectively.
So when attempting to use all dimensions , it fails in make_block_2d_copy_* at:
**/include/cute/atom/copy_traits_xe_2d.hpp:503:18
auto x_shape = elem_scale(ShapeTiler_MN{}, atom_shape);The failure occurs in the elem_scale function due to a tuple size mismatch:
Tuple 1: cute::tuple<cute::C<32>, cute::C<32>, cute::C<1>> (3 elements)
Tuple 2: cute::tuple<cute::C<8>, cute::C<1>> (2 elements)
The transform function requires both tuples to have identical sizes, but the mismatch (3 ≠ 2) causes the compilation error.**
-
Missing Copy Atom Members:
The make_block_2d_copy_A(GmemTiledCopyA{}, TiledMma{}, mA_mkl) approach is also failing because the old copy atom structures cute::XE_2D_U16x32x32_LD_N and cute::XE_2D_U16x32x32_LD_V are missing below three member variables:1. AtomWidth 2. AtomHeight 3. CopyBits
**The make_block_2d_copy_X function expects these members to be present in the copy operation classes, as shown in the compilation errors at:
/include/cute/atom/copy_traits_xe_2d.hpp:740:33
constexpr int Width = CopyOp::AtomWidth * CopyOp::CopyBits / sizeof_bits_v;
/include/cute/atom/copy_traits_xe_2d.hpp:741:34
constexpr int Height = CopyOp::AtomHeight;**Even after defining those member variables in old atom I am landing on missing type definitions and template parameter mismatch errors.. seems old atoms not compatible with make_block_2d_copy_*
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On 1 -- let me look into it, may be a bug in make_block_2d_copy_*
On 2 -- we don't need to support the old atoms (as discussed offline). Compile errors here are expected if you tried to use the old atoms (maybe we can add some code to make_block_2d_copy_* to catch usages of the old ops and provide a nicer compile-time error).
Modify 00_bmg_gemm to include new mma and copy atoms (#477).
00_bmg_gemm combines two parts: mma and epilogue. To add new atom changes, we need to update both parts since they currently use old atoms. As starting we will: