Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit a9d5cc9

Browse files
Shuffle kernel (#235)
* todo: update client ut * tmp * shuf ut * todo: fix shuf bug when m=8 * fix * typo & ordered queue * fix comment
1 parent c4e8540 commit a9d5cc9

File tree

11 files changed

+779
-49
lines changed

11 files changed

+779
-49
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2023-2024 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @file
18+
/// C++ API
19+
20+
#pragma once
21+
22+
#include <experimental/kernel/col_major_shuf/common.hpp>
23+
#include <experimental/kernel/col_major_shuf/config.hpp>
24+
25+
namespace gpu::xetla::kernel {
26+
27+
/// @brief
28+
///
29+
/// @tparam dtype_in_ input data type.
30+
/// @tparam dtype_out_ output data type.
31+
/// @tparam dtype_gidx_ gidx data type.
32+
/// @tparam mem_layout_in_ input memory layout.
33+
/// @tparam col_major_shuf_attr_ parallel-related attributes.
34+
/// @tparam arch_ HW architecture.
35+
template <
36+
typename dtype_in_,
37+
typename dtype_out_,
38+
typename dtype_gidx_,
39+
mem_layout mem_layout_in_,
40+
typename col_major_shuf_attr_,
41+
gpu_arch arch_,
42+
typename enable = void>
43+
struct col_major_shuf_t {};
44+
45+
} // namespace gpu::xetla::kernel
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2023-2024 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @file
18+
/// C++ API
19+
20+
#pragma once
21+
22+
#include <experimental/kernel/col_major_shuf/api.hpp>
23+
#include <experimental/kernel/col_major_shuf/common.hpp>
24+
#include <experimental/kernel/col_major_shuf/config.hpp>
25+
#include <experimental/kernel/col_major_shuf/col_major_shuf_xe.hpp>
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2023-2024 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @file
18+
/// C++ API
19+
20+
#pragma once
21+
22+
#include <experimental/kernel/col_major_shuf/api.hpp>
23+
#include <experimental/kernel/col_major_shuf/common.hpp>
24+
#include <experimental/kernel/col_major_shuf/config.hpp>
25+
26+
namespace gpu::xetla::kernel {
27+
template <
28+
typename dtype_in_,
29+
typename dtype_out_,
30+
typename dtype_gidx_,
31+
mem_layout mem_layout_in_,
32+
typename col_major_shuf_attr_,
33+
gpu_arch arch_>
34+
struct col_major_shuf_t<
35+
dtype_in_,
36+
dtype_out_,
37+
dtype_gidx_,
38+
mem_layout_in_,
39+
col_major_shuf_attr_,
40+
arch_> {
41+
using dtype_in = dtype_in_;
42+
using dtype_out = dtype_out_;
43+
using dtype_gidx = dtype_gidx_;
44+
using col_major_shuf_attr = col_major_shuf_attr_;
45+
46+
static constexpr mem_layout mem_layout_in = mem_layout_in_;
47+
48+
static_assert(
49+
std::is_same<dtype_in, dtype_out>::value,
50+
"only support in/out data type must be same now.");
51+
static_assert(
52+
mem_layout_in == mem_layout::row_major,
53+
"only support row_major input now.");
54+
static_assert(
55+
std::is_same<dtype_gidx, uint32_t>::value,
56+
"dtype_gidx must be uint32_t");
57+
58+
static constexpr uint32_t wg_tile_x = col_major_shuf_attr::wg_tile_x;
59+
static constexpr uint32_t wg_tile_y = col_major_shuf_attr::wg_tile_y;
60+
static constexpr uint32_t sg_tile_x = col_major_shuf_attr::sg_tile_x;
61+
static constexpr uint32_t sg_tile_y = col_major_shuf_attr::sg_tile_y;
62+
63+
static constexpr uint32_t tile_size_x = sg_tile_x;
64+
static constexpr uint32_t tile_size_y = sg_tile_y;
65+
66+
static constexpr uint32_t block_size_x =
67+
col_major_shuf_attr::load_block_size; // TODO(zhe:) add load block size
68+
// check under different arch
69+
70+
static constexpr uint32_t dev_mem_align = 64;
71+
using mem_desc_store_tile_t = mem_desc_t<
72+
dtype_in,
73+
mem_layout_in,
74+
mem_space::global,
75+
dev_mem_align / sizeof(dtype_in)>;
76+
using store_tile_desc_t = subgroup::tile_desc_t<
77+
tile_size_x,
78+
tile_size_y,
79+
block_size_x,
80+
tile_size_y,
81+
reg_layout::tiled>;
82+
using store_tile_t = subgroup::tile_t<dtype_out, store_tile_desc_t>;
83+
using store_tile_payload_t = subgroup::mem_payload_t<
84+
mem_desc_store_tile_t,
85+
store_tile_desc_t,
86+
subgroup::msg_type_v<store_tile_desc_t, mem_space::global>,
87+
arch_>;
88+
89+
using mem_desc_gidx_t = mem_desc_t<
90+
dtype_gidx,
91+
mem_layout::row_major,
92+
mem_space::global,
93+
dev_mem_align / sizeof(dtype_gidx)>;
94+
using gidx_tile_desc_t =
95+
subgroup::tile_desc_t<tile_size_x, 1, block_size_x, 1, reg_layout::tiled>;
96+
using gidx_t = subgroup::tile_t<dtype_gidx, gidx_tile_desc_t>;
97+
using gidx_payload_t = subgroup::mem_payload_t<
98+
mem_desc_gidx_t,
99+
gidx_tile_desc_t,
100+
subgroup::msg_type_v<gidx_tile_desc_t, mem_space::global>,
101+
arch_>;
102+
103+
struct arguments_t {
104+
dtype_in* mat_in_ptr;
105+
dtype_out* mat_out_ptr;
106+
dtype_gidx* gidx_ptr;
107+
uint32_t matrix_x;
108+
uint32_t matrix_y;
109+
};
110+
111+
__XETLA_API static void call(sycl::nd_item<3>& item, arguments_t& args) {
112+
int gid_x = item.get_group(2);
113+
int gid_y = item.get_group(1);
114+
int x_dim_offset = gid_x * wg_tile_x;
115+
int y_dim_offset = gid_y * wg_tile_y;
116+
int tid_x = item.get_local_id(2);
117+
int tid_y = item.get_local_id(1);
118+
x_dim_offset += tid_x * sg_tile_x;
119+
y_dim_offset += tid_y * sg_tile_y;
120+
mem_desc_gidx_t gidx_desc(
121+
args.gidx_ptr, {args.matrix_x, 1, args.matrix_x}, {x_dim_offset, 0});
122+
mem_desc_store_tile_t store_tile_desc(
123+
args.mat_out_ptr,
124+
{args.matrix_x, args.matrix_y, args.matrix_x},
125+
{x_dim_offset, y_dim_offset});
126+
127+
static constexpr int block_x_num = tile_size_x / block_size_x;
128+
static constexpr int elt_per_block = block_size_x * tile_size_y;
129+
store_tile_t store_tile;
130+
store_tile_payload_t store_tile_payload(store_tile_desc);
131+
gidx_payload_t gidx_payload(gidx_desc);
132+
133+
#pragma unroll
134+
for (int block_x = 0; block_x < block_x_num; block_x++) {
135+
auto gidx = xetla_load_global<
136+
uint32_t,
137+
block_size_x,
138+
data_size::default_size,
139+
cache_hint::cached,
140+
cache_hint::cached>(
141+
args.gidx_ptr, gidx_payload.base_offset + block_x * block_size_x);
142+
#pragma unroll
143+
for (uint32_t row = 0; row < tile_size_y; row++) {
144+
store_tile.reg.xetla_select<block_size_x, 1>(
145+
block_x * elt_per_block + row * block_size_x) =
146+
xetla_load_global<
147+
dtype_in,
148+
1,
149+
data_size::default_size,
150+
cache_hint::cached,
151+
cache_hint::cached,
152+
block_size_x>(
153+
args.mat_in_ptr + (y_dim_offset + row) * args.matrix_x,
154+
gidx,
155+
1);
156+
}
157+
}
158+
tile_store(store_tile, store_tile_payload);
159+
};
160+
};
161+
} // namespace gpu::xetla::kernel
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2023-2024 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @file
18+
/// C++ API
19+
20+
#pragma once
21+
22+
#include <common/common.hpp>
23+
#include <group/group.hpp>
24+
#include <subgroup/subgroup.hpp>
25+
26+
namespace gpu::xetla {} // namespace gpu::xetla
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2023-2024 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @file
18+
/// C++ API
19+
20+
#pragma once
21+
22+
#include <experimental/kernel/layer_norm/common.hpp>
23+
24+
namespace gpu::xetla::kernel {
25+
26+
/// @brief Sets up attribute of the layer norm.
27+
///
28+
/// @tparam wg_tile_x_ Is the num of cols processed by one workgroup.
29+
/// @tparam wg_tile_y_ Is the num of rows processed by one workgroup.
30+
/// @tparam sg_tile_x_ Is the num of cols processed by one subgroup.
31+
/// @tparam sg_tile_y_ Is the num of rows processed by one subgroup.
32+
/// @tparam load_block_size_ Is the size of block when load x dimenstion.
33+
/// kernels have spills.
34+
template <
35+
uint32_t wg_tile_x_,
36+
uint32_t wg_tile_y_,
37+
uint32_t sg_tile_x_,
38+
uint32_t sg_tile_y_,
39+
uint32_t load_block_size_>
40+
struct col_major_shuf_attr_t {
41+
static constexpr uint32_t wg_tile_x = wg_tile_x_;
42+
static constexpr uint32_t wg_tile_y = wg_tile_y_;
43+
static constexpr uint32_t sg_tile_x = sg_tile_x_;
44+
static constexpr uint32_t sg_tile_y = sg_tile_y_;
45+
static constexpr uint32_t load_block_size = load_block_size_;
46+
47+
static_assert(
48+
wg_tile_x % sg_tile_x == 0,
49+
"Current design we don't enable the boundary check");
50+
static_assert(
51+
sg_tile_x % load_block_size == 0 && sg_tile_x >= load_block_size,
52+
"Current design we don't enable the boundary check on chunking "
53+
"mechanism");
54+
};
55+
56+
} // namespace gpu::xetla::kernel

include/experimental/kernel/kernel.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#pragma once
2121

22+
#include <experimental/kernel/col_major_shuf/col_major_shuf.hpp>
2223
#include <experimental/kernel/data_transformer/data_transformer.hpp>
2324
#include <experimental/kernel/gemm/gemm.hpp>
2425
#include <experimental/kernel/layer_norm/layer_norm.hpp>

tests/integration/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@ add_subdirectory(sg_dropout_op)
2929
add_subdirectory(limitation)
3030
add_subdirectory(softmax)
3131
add_subdirectory(fmha)
32+
add_subdirectory(col_major_shuf)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
get_filename_component(ProjectId ${CMAKE_CURRENT_SOURCE_DIR} NAME)
2+
string(REPLACE " " "_" ProjectId ${ProjectId})
3+
4+
FILE(GLOB src main.cpp)
5+
add_integration_test(${ProjectId} "${src}")

0 commit comments

Comments
 (0)