Skip to content

Commit d32e8dd

Browse files
Resolve conflicts
1 parent 919df17 commit d32e8dd

File tree

5 files changed

+17
-20
lines changed

5 files changed

+17
-20
lines changed

include/cutlass/gemm/collective/xe_array_mma_mixed_input.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -173,23 +173,23 @@ struct CollectiveMma<
173173

174174
using MmaAtomShape = typename TiledMma::AtomShape_MNK;
175175

176-
static constexpr auto BLK_M = get<0>(WorkgroupTileShape{});
177-
static constexpr auto BLK_N = get<1>(WorkgroupTileShape{});
178-
static constexpr auto BLK_K = get<2>(WorkgroupTileShape{});
176+
static constexpr int BLK_M = get<0>(WorkgroupTileShape{});
177+
static constexpr int BLK_N = get<1>(WorkgroupTileShape{});
178+
static constexpr int BLK_K = get<2>(WorkgroupTileShape{});
179179

180-
static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
181-
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
182-
static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
180+
static constexpr int ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
181+
static constexpr int ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
182+
static constexpr int ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
183183

184-
static constexpr auto SG_M = ceil_div(BLK_M, ATOM_M);
185-
static constexpr auto SG_N = ceil_div(BLK_N, ATOM_N);
186-
static constexpr auto SG_K = ceil_div(BLK_K, ATOM_K);
184+
static constexpr int SG_M = ceil_div(BLK_M, ATOM_M);
185+
static constexpr int SG_N = ceil_div(BLK_N, ATOM_N);
186+
static constexpr int SG_K = ceil_div(BLK_K, ATOM_K);
187187
using SubgroupTileShape = Shape<decltype(SG_M), decltype(SG_N), decltype(SG_K)>;
188188

189189
using GmemTiledCopyScale = typename scale_zero_copy_traits<NonVoidElementScale, SG_N>::type;
190190
using GmemTiledCopyZero = typename scale_zero_copy_traits<NonVoidElementZero, SG_N, InternalNonVoidStrideZero>::type;
191191

192-
static constexpr auto Num_SGs = ATOM_N * ATOM_M * ATOM_K;
192+
static constexpr int Num_SGs = ATOM_N * ATOM_M * ATOM_K;
193193
static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{});
194194

195195
using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;

test/unit/cute/turing/movm.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
using namespace cute;
4242

4343
#ifdef CUTLASS_ENABLE_SYCL
44-
namespace sc = syclcompat;
45-
namespace sc_exp = syclcompat::experimental;
44+
namespace sc = cutlasscompat;
45+
namespace sc_exp = cutlasscompat::experimental;
4646
namespace sycl_ext = sycl::ext::oneapi::experimental;
4747
#endif
4848

tools/util/include/cutlass/util/initialize_block.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
#pragma once
3636
#ifdef CUTLASS_ENABLE_SYCL
37-
#include <syclcompat.hpp>
37+
#include <cutlasscompat.hpp>
3838
#else
3939
#include <cuda.h>
4040
#endif
@@ -111,7 +111,7 @@ bool initialize_block(Element* block, std::size_t size, uint64_t seed, Args_t&&.
111111
}
112112
}
113113

114-
syclcompat::wait();
114+
cutlasscompat::wait();
115115
return true;
116116
}
117117

@@ -189,7 +189,7 @@ void initialize_mixed_dtype_block(cutlass::DeviceAllocation<T1>& block_device,
189189
}
190190
}
191191

192-
syclcompat::wait();
192+
cutlasscompat::wait();
193193
}
194194

195195
#undef CUDA_CHECK

tools/util/include/cutlass/util/reference/device/sycl_tensor_fill.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ void BlockFillRandomUniformCopyFromHost(
189189
for (size_t i = 0; i < capacity; ++i) {
190190
buff[i] = (Element)(dis(gen));
191191
}
192-
syclcompat::memcpy<Element>(ptr, buff.data(), capacity);
193-
syclcompat::wait();
192+
cutlasscompat::memcpy<Element>(ptr, buff.data(), capacity);
193+
cutlasscompat::wait();
194194
} else {
195195
assert(false && "Not supported dtype");
196196
}

tools/util/include/cutlasscompat/launch.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333
#include <cutlasscompat/dims.hpp>
3434
#include <cutlasscompat/launch_policy.hpp>
3535

36-
#include <cute/container/type_list.hpp>
37-
3836
namespace cutlasscompat {
3937

4038
namespace detail {
@@ -154,7 +152,6 @@ sycl::event launch(LaunchPolicy launch_policy, sycl::queue q, Args... args) {
154152
template <auto F, class N=sycl::detail::auto_name, typename LaunchPolicy, typename... Args>
155153
sycl::event launch(LaunchPolicy launch_policy, sycl::queue q, Args... args) {
156154
static_assert(detail::is_launch_policy_v<LaunchPolicy>);
157-
//using FN = std::conditional_t<std::is_same_v<N, sycl::detail::auto_name>, sycl::detail::auto_name, cute::type_list<decltype(F), N>>;
158155
return detail::launch<F, N>(launch_policy, q, args...);
159156
}
160157

0 commit comments

Comments
 (0)