Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 5456fc0

Browse files
committed
sync ipex 20240618
1 parent 957c5a4 commit 5456fc0

File tree

7 files changed

+81
-11
lines changed

7 files changed

+81
-11
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2022-2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @file
18+
/// C++ API
19+
20+
#pragma once
21+
22+
namespace gpu::xetla {
23+
24+
/// @brief XeTLA 4-bit data type, stored packed in an 8-bit value.
25+
/// Two 4-bit values are packed into one byte.
26+
struct int4x2 {
27+
uint8_t data; // The packed byte. NOTE(review): no <cstdint> include is visible in this header; uint8_t presumably comes from an earlier include - confirm.
28+
29+
operator uint8_t() const { // Implicit conversion back to the raw packed byte.
30+
return data;
31+
}
32+
int4x2(uint8_t val) { // Not explicit, so a uint8_t converts implicitly to int4x2.
33+
data = val; // The byte is stored unchanged; no masking or unpacking happens here.
34+
}
35+
};
36+
37+
/// @brief Marks int4x2 as an XeTLA internal data type: this explicit specialization sets is_internal_type<int4x2>::value to true.
38+
template <> // NOTE(review): the primary is_internal_type template is not declared in this header; presumably it is provided by a header included earlier - confirm.
39+
struct is_internal_type<int4x2> {
40+
static constexpr bool value = true; // Compile-time flag read as is_internal_type<int4x2>::value.
41+
};
42+
43+
/// @brief Sets uint8_t as the native data type of int4x2: native_type<int4x2>::type is an alias for uint8_t.
44+
template <> // NOTE(review): the primary native_type template is not declared in this header; presumably it is provided by a header included earlier - confirm.
45+
struct native_type<int4x2> {
46+
using type = uint8_t; // Same type as the int4x2::data member and its conversion operator.
47+
};
48+
49+
} // namespace gpu::xetla
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/*******************************************************************************
2+
* Copyright (c) 2022-2023 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*******************************************************************************/
16+
17+
/// @file
18+
/// C++ API
19+
20+
#pragma once
21+
22+
#include <common/common.hpp>
23+
#include <experimental/common/base_types.hpp>

include/experimental/kernel/gemm/impl/int4_dequantize_kslicing_xe.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,8 @@ class gemm_universal_t<
558558
// args.matrix_m,
559559
// args.matC_ld);
560560
// } else {
561-
// implementable &= kernel::general_1d<arch_tag, dtype_c>::check_alignment(
561+
// implementable &= kernel::general_1d<arch_tag,
562+
// dtype_c>::check_alignment(
562563
// args.matC_base.base, args.matC_ld);
563564
// }
564565
// }

include/experimental/kernel/layer_norm/layer_norm_fwd_xe.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ struct layer_norm_fwd_t<
326326
itr_count += 1;
327327
nbarrier.wait();
328328

329-
xetla_vector<dtype_acc, wg_size_x * 2> mu_m2_vec =
329+
xetla_vector<dtype_acc, wg_size_x* 2> mu_m2_vec =
330330
xetla_load_local<dtype_acc, wg_size_x * 2>(slm_load_base);
331331
xetla_vector<dtype_acc, wg_size_x> mu_vec =
332332
mu_m2_vec.xetla_select<wg_size_x, 2>(0);

include/kernel/gemm/impl/default_xe.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,8 @@ class gemm_universal_t<
283283
// args.matrix_m,
284284
// args.matC_ld);
285285
// } else {
286-
// implementable &= kernel::general_1d<arch_tag, dtype_c>::check_alignment(
286+
// implementable &= kernel::general_1d<arch_tag,
287+
// dtype_c>::check_alignment(
287288
// args.matC_base.base, args.matC_ld);
288289
// }
289290
// }

include/kernel/gemm/impl/stream_k_xe.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,8 @@ class gemm_universal_t<
337337
// args.matrix_m,
338338
// args.matC_ld);
339339
// } else {
340-
// implementable &= kernel::general_1d<arch_tag, dtype_c>::check_alignment(
340+
// implementable &= kernel::general_1d<arch_tag,
341+
// dtype_c>::check_alignment(
341342
// args.matC_base.base, args.matC_ld);
342343
// }
343344
// }

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ struct mem_payload_t<
8686
xetla_vector<uint32_t, 16 * num_block> payloads;
8787

8888
inline mem_payload_t(const this_payload_t& rhs) {
89-
this->payload = rhs.payload;
89+
this->payloads = rhs.payloads;
9090
}
9191

9292
inline mem_payload_t(mem_desc_t& mem_desc) {
@@ -159,7 +159,7 @@ struct mem_payload_t<
159159
// ~mem_payload_t(){}
160160

161161
inline this_payload_t& operator=(const this_payload_t& rhs) {
162-
this->payload = rhs.payload;
162+
this->payloads = rhs.payloads;
163163
return *this;
164164
}
165165

@@ -1739,9 +1739,6 @@ struct prefetch_payload_t<
17391739
this->width_in_elems = rhs.width_in_elems;
17401740
this->height_in_elems = rhs.height_in_elems;
17411741

1742-
this->step_x = rhs.step_x;
1743-
this->step_y = rhs.step_y;
1744-
17451742
this->channel_offset = rhs.channel_offset;
17461743
}
17471744

@@ -1756,8 +1753,6 @@ struct prefetch_payload_t<
17561753
this->width_in_elems = rhs.width_in_elems;
17571754
this->height_in_elems = rhs.height_in_elems;
17581755

1759-
this->step_x = rhs.step_x;
1760-
this->step_y = rhs.step_y;
17611756
this->channel_offset = rhs.channel_offset;
17621757
return *this;
17631758
}

0 commit comments

Comments
 (0)