Commit 866b8fb

Optimize layout for SubgroupMatrixLoad on Intel
This introduces a new LayoutProgram to pre-process the input matrix A, converting it to a layout that is more efficient for the SubgroupMatrixLoad operation on Intel GPUs.
1 parent aa644e8
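
For reference, the rearrangement performed by the new LayoutProgram can be expressed on the CPU as below. This is an illustrative sketch, not code from the commit: the function name and float element type are assumptions, and M is assumed to be a multiple of m (the real pass pads M up via its dispatch size). It reproduces the tile ordering documented in the shader comment in the diff.

#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch: rearrange a row-major MxK matrix so that every m x k subgroup tile
// is contiguous in memory, mirroring what LayoutProgram does per workgroup.
std::vector<float> TileForSubgroupLoad(const std::vector<float>& a,
                                       uint32_t M, uint32_t K,
                                       uint32_t m, uint32_t k) {
  std::vector<float> out(a.size());
  const uint32_t tiles_per_row = K / k;  // K must be a multiple of k
  for (uint32_t tr = 0; tr < M / m; ++tr) {            // maps to workgroup_id.x
    for (uint32_t tc = 0; tc < tiles_per_row; ++tc) {  // maps to workgroup_id.y
      // Flat offset of this tile in the output; matches the shader's out_offset.
      const uint32_t tile_base = (tr * tiles_per_row + tc) * m * k;
      for (uint32_t r = 0; r < m; ++r) {
        for (uint32_t c = 0; c < k; ++c) {
          out[tile_base + r * k + c] = a[(tr * m + r) * K + (tc * k + c)];
        }
      }
    }
  }
  return out;
}

int main() {
  // The "M = 4, K = 4, m = 2, k = 2" example from the shader comment:
  std::vector<float> a(16);
  for (uint32_t i = 0; i < 16; ++i) a[i] = static_cast<float>(i);  // d00..d33
  for (float v : TileForSubgroupLoad(a, 4, 4, 2, 2)) printf("%g ", v);
  printf("\n");  // prints: 0 1 4 5 2 3 6 7 8 9 12 13 10 11 14 15
}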

File tree

2 files changed: +102 -19 lines changed
onnxruntime/contrib_ops/webgpu/quantization/subgroup_matrix_matmul_nbits.cc

Lines changed: 101 additions & 19 deletions
@@ -73,23 +73,76 @@ bool IsSubgroupMatrixConfigSupportedOnIntel(onnxruntime::webgpu::ComputeContext&
   return false;
 }
 
+// This program optimizes the layout of the input matrix A (MxK) for SubgroupMatrixLoad, so that all elements of
+// each subgroup matrix (mxk) are arranged contiguously in memory.
+// Take "M = 4, K = 4, m = 2, k = 2" as an example: the input matrix A is arranged in row-major order as follows:
+//     d00, d01, | d02, d03,
+//     d10, d11, | d12, d13,
+//     ---------------------
+//     d20, d21, | d22, d23,
+//     d30, d31, | d32, d33,
+//
+// The layout program rearranges the input matrix A into the following order:
+//     d00, d01, d10, d11,
+//     -------------------
+//     d02, d03, d12, d13,
+//     -------------------
+//     d20, d21, d30, d31,
+//     -------------------
+//     d22, d23, d32, d33,
+class LayoutProgram final : public Program<LayoutProgram> {
+ public:
+  LayoutProgram(uint32_t m, uint32_t k, std::string_view component_type) : Program{"SubgroupMatrixMatMulLayout"},
+                                                                           m_(m), k_(k), component_type_(component_type) {}
+  Status GenerateShaderCode(ShaderHelper& sh) const override;
+  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+      {"M", ProgramUniformVariableDataType::Uint32},
+      {"K", ProgramUniformVariableDataType::Uint32});
+ private:
+  uint32_t m_;
+  uint32_t k_;
+  std::string_view component_type_;
+};
+
+Status LayoutProgram::GenerateShaderCode(ShaderHelper& shader) const {
+  shader.AddInput("input_a", ShaderUsage::UseUniform);
+  shader.AddOutput("output_a", ShaderUsage::UseUniform);
+  shader.AdditionalImplementation() << "alias component_type = " << component_type_ << ";\n"
+                                    << "const m_dim: u32 = " << m_ << ";\n"
+                                    << "const k_dim: u32 = " << k_ << ";\n";
+
+  shader.MainFunctionBody() << R"MAIN_FN(
+    let M = uniforms.M;
+    let K = uniforms.K;
+    let in_offset = workgroup_id.x * m_dim * K + workgroup_id.y * k_dim;
+    let out_offset = (workgroup_id.x * K / k_dim + workgroup_id.y) * m_dim * k_dim;
+
+    // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride
+    var mat: subgroup_matrix_left<component_type, k_dim, m_dim> =
+        subgroupMatrixLoad<subgroup_matrix_left<component_type, k_dim, m_dim>>(&input_a, in_offset, false, uniforms.K);
+    subgroupMatrixStore(&output_a, out_offset, mat, false, k_dim);
+)MAIN_FN";
+  return Status::OK();
+}
+
+
 Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t config_index, bool has_zero_points) {
   auto& config = intel_supported_subgroup_matrix_configs[config_index];
   shader.AdditionalImplementation() << "alias component_type = " << ComponentTypeName[static_cast<uint32_t>(std::get<2>(config))] << ";\n"
                                     << "alias result_component_type = " << ComponentTypeName[static_cast<uint32_t>(std::get<3>(config))] << ";\n"
-                                    << "const m_dim = " << std::get<4>(config) << ";\n"
-                                    << "const n_dim = " << std::get<5>(config) << ";\n"
-                                    << "const k_dim = " << std::get<6>(config) << ";\n";
+                                    << "const m_dim: u32 = " << std::get<4>(config) << ";\n"
+                                    << "const n_dim: u32 = " << std::get<5>(config) << ";\n"
+                                    << "const k_dim: u32 = " << std::get<6>(config) << ";\n";
 
   shader.AdditionalImplementation() << R"ADDNL_FN(
-  const tile_cols = 64;
-  const tile_rows = 64;
-  const tile_k = 32;
-  const subtile_rows = 8;
-  const quantization_block_size = 32;
-
-  var<workgroup> tile_B: array<component_type, tile_cols * tile_k>; // 64 x 32 - RxC
-)ADDNL_FN" << GenerateZeroPointReadingCode(nbits, has_zero_points, "component_type");
+  const tile_cols: u32 = 64;
+  const tile_rows: u32 = 64;
+  const tile_k: u32 = 32;
+  const subtile_rows: u32 = 8;
+  const quantization_block_size: u32 = 32;
+
+  var<workgroup> tile_B: array<component_type, tile_cols * tile_k>; // 64 x 32 - RxC
+)ADDNL_FN" << GenerateZeroPointReadingCode(nbits, has_zero_points, "component_type");
   if (nbits == 4) {
     shader.AdditionalImplementation() << R"ADDNL_FN(
     fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) {
@@ -153,7 +206,12 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t c
     let a_global_base = workgroup_id.y * tile_rows;
     let b_global_base = workgroup_id.x * tile_cols;
 
-    let subtile_id = u32(local_idx / sg_size);
+    let subgroup_id = u32(local_idx / sg_size);
+    let a_subtile_num_per_tensor_row = u32(uniforms.K / k_dim);
+    let a_subtile_num_per_tile_col = u32(tile_rows / m_dim);
+    let a_subtile_id = (workgroup_id.y * a_subtile_num_per_tile_col + subgroup_id) * a_subtile_num_per_tensor_row;
+    let a_subtile_size = m_dim * k_dim;
+    var matrix_a_offset = a_subtile_id * a_subtile_size;
 
     var matC00: subgroup_matrix_result<result_component_type, n_dim, m_dim>;
     var matC01: subgroup_matrix_result<result_component_type, n_dim, m_dim>;
@@ -167,9 +225,9 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t c
     for (var step: u32 = 0; step < tile_k; step += k_dim)
     {
       // Load A from global memory.
-      let matrix_a_offset = (a_global_base + subtile_id * subtile_rows) * uniforms.K + kidx + step;
       // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride
-      var matA0: subgroup_matrix_left<component_type, k_dim, m_dim> = subgroupMatrixLoad<subgroup_matrix_left<component_type, k_dim, m_dim>>(&input_a, matrix_a_offset, false, uniforms.K);
+      var matA0: subgroup_matrix_left<component_type, k_dim, m_dim> = subgroupMatrixLoad<subgroup_matrix_left<component_type, k_dim, m_dim>>(&input_a, matrix_a_offset, false, k_dim);
+      matrix_a_offset += a_subtile_size;
 
       // Load B from shared local memory.
       // tile_B is stored as column major.
@@ -192,10 +250,10 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t c
 
     // Write out
     let matrix_c_offset = (a_global_base) * uniforms.N + b_global_base;
-    subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N, matC00, false, uniforms.N);
-    subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + n_dim, matC01, false, uniforms.N);
-    subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + 2 * n_dim, matC02, false, uniforms.N);
-    subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + 3 * n_dim, matC03, false, uniforms.N);
+    subgroupMatrixStore(&output, matrix_c_offset + subgroup_id * m_dim * uniforms.N, matC00, false, uniforms.N);
+    subgroupMatrixStore(&output, matrix_c_offset + subgroup_id * m_dim * uniforms.N + n_dim, matC01, false, uniforms.N);
+    subgroupMatrixStore(&output, matrix_c_offset + subgroup_id * m_dim * uniforms.N + 2 * n_dim, matC02, false, uniforms.N);
+    subgroupMatrixStore(&output, matrix_c_offset + subgroup_id * m_dim * uniforms.N + 3 * n_dim, matC03, false, uniforms.N);
 )MAIN_FN";
 
   return Status::OK();
@@ -426,6 +484,30 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te
                                       int32_t config_index,
                                       onnxruntime::webgpu::ComputeContext& context,
                                       Tensor* y) {
+  const auto& config = intel_supported_subgroup_matrix_configs[config_index];
+  const auto component_type = ComponentTypeName[static_cast<uint32_t>(std::get<2>(config))];
+  const auto m = std::get<4>(config);
+  const auto k = std::get<6>(config);
+
+  // Optimize the layout of the input matrix A (MxK) for SubgroupMatrixLoad.
+  LayoutProgram layout_program{m, k, component_type};
+  constexpr uint32_t kSubgroupSize = 32;
+  layout_program.SetWorkgroupSize(kSubgroupSize);
+
+  const auto dispatch_group_size_x = (M + m - 1) / m;
+  ORT_ENFORCE(K % k == 0, "K must be a multiple of ", k);
+  const auto dispatch_group_size_y = K / k;
+  // Each workgroup will process one subgroup matrix of size m x k.
+  layout_program.SetDispatchGroupSize(dispatch_group_size_x, dispatch_group_size_y, 1);
+
+  TensorShape a_layout_shape{dispatch_group_size_x * m, K};
+  Tensor a_layout = context.CreateGPUTensor(a->DataType(), a_layout_shape);
+  layout_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}})
+      .AddOutputs({{&a_layout, ProgramTensorMetadataDependency::Rank, a_layout.Shape(), 1}})
+      .AddUniformVariables({{static_cast<uint32_t>(M)},
+                            {static_cast<uint32_t>(K)}});
+  ORT_RETURN_IF_ERROR(context.RunProgram(layout_program));
+
   uint32_t tile_size_a = 32;
   uint32_t work_group_size = 128;
   constexpr uint32_t kTileSizeB = 64;
@@ -441,7 +523,7 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te
   mul_program.SetDispatchGroupSize(
       (N + kTileSizeB - 1) / kTileSizeB,
       (M + tile_size_a - 1) / tile_size_a, 1);
-  mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1},
+  mul_program.AddInputs({{&a_layout, ProgramTensorMetadataDependency::TypeAndRank, 1},
                          {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(nbits == 4 ? kU32Components : 2 * kU32Components)},
                          {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
       .AddUniformVariables({{static_cast<uint32_t>(M)},
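
To make the dispatch arithmetic of the new layout pass in ApplySubgroupMatrixMatMulNBits concrete, here is a small standalone check with assumed example values (the real m and k come from intel_supported_subgroup_matrix_configs; the shapes below are illustrative, not from the commit):

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t M = 100, K = 1024;  // example GEMM shape
  const uint32_t m = 8, k = 16;      // example subgroup matrix shape
  const uint32_t groups_x = (M + m - 1) / m;  // 13: one workgroup per tile row, M rounded up
  const uint32_t groups_y = K / k;            // 64: K must divide evenly (ORT_ENFORCE)
  // a_layout is allocated with groups_x * m rows, i.e. M padded to a multiple of m.
  printf("dispatch %u x %u, a_layout %u x %u\n",
         groups_x, groups_y, groups_x * m, K);  // dispatch 13 x 64, a_layout 104 x 1024
}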

vk_build.bat

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+build.bat --config RelWithDebInfo --parallel --skip_submodule_sync --skip_tests --use_webgpu --build_shared_lib --enable_pybind --build_wheel --cmake_extra_defines onnxruntime_ENABLE_DAWN_BACKEND_VULKAN=ON --cmake_extra_defines onnxruntime_ENABLE_DAWN_BACKEND_D3D12=OFF --cmake_generator "Visual Studio 17 2022"
