@@ -73,23 +73,76 @@ bool IsSubgroupMatrixConfigSupportedOnIntel(onnxruntime::webgpu::ComputeContext&
  return false;
}

+ // This program optimizes the layout of the input matrix A (M x K) for SubgroupMatrixLoad, so that all elements of
+ // each subgroup matrix (m x k) are arranged contiguously in memory.
+ // Take "M = 4, K = 4, m = 2, k = 2" as an example. The input matrix A is arranged in row-major order as follows:
+ //   d00, d01, | d02, d03,
+ //   d10, d11, | d12, d13,
+ //   ---------------------
+ //   d20, d21, | d22, d23,
+ //   d30, d31, | d32, d33,
+ //
+ // The layout program rearranges the input matrix A into the following order:
+ //   d00, d01, d10, d11,
+ //   -------------------
+ //   d02, d03, d12, d13,
+ //   -------------------
+ //   d20, d21, d30, d31,
+ //   -------------------
+ //   d22, d23, d32, d33,
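+ // Equivalently, element A[r][c] of the original matrix lands at flat offset
+ //   ((r / m) * (K / k) + (c / k)) * (m * k) + (r % m) * k + (c % k)
+ // in the rearranged buffer (K is enforced to be a multiple of k by the caller).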
+ class LayoutProgram final : public Program<LayoutProgram> {
+  public:
+   LayoutProgram(uint32_t m, uint32_t k, std::string_view component_type) : Program{"SubgroupMatrixMatMulLayout"},
+                                                                            m_(m), k_(k), component_type_(component_type) {}
+   Status GenerateShaderCode(ShaderHelper& sh) const override;
+   WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
+       {"M", ProgramUniformVariableDataType::Uint32},
+       {"K", ProgramUniformVariableDataType::Uint32});
+  private:
+   uint32_t m_;
+   uint32_t k_;
+   std::string_view component_type_;
+ };
+
+ Status LayoutProgram::GenerateShaderCode(ShaderHelper& shader) const {
+   shader.AddInput("input_a", ShaderUsage::UseUniform);
+   shader.AddOutput("output_a", ShaderUsage::UseUniform);
+   shader.AdditionalImplementation() << "alias component_type = " << component_type_ << ";\n"
+                                     << "const m_dim: u32 = " << m_ << ";\n"
+                                     << "const k_dim: u32 = " << k_ << ";\n";
+
+   shader.MainFunctionBody() << R"MAIN_FN(
+   let M = uniforms.M;
+   let K = uniforms.K;
+   let in_offset = workgroup_id.x * m_dim * K + workgroup_id.y * k_dim;
+   let out_offset = (workgroup_id.x * K / k_dim + workgroup_id.y) * m_dim * k_dim;
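+   // in_offset: element offset of this workgroup's m_dim x k_dim subtile in row-major A (row stride K).
+   // out_offset: offset of the same subtile in the output, where subtiles are stored back to back,
+   // K / k_dim subtiles per row block, m_dim * k_dim elements each.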
+
+   // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride
+   var mat: subgroup_matrix_left<component_type, k_dim, m_dim> =
+       subgroupMatrixLoad<subgroup_matrix_left<component_type, k_dim, m_dim>>(&input_a, in_offset, false, uniforms.K);
+   subgroupMatrixStore(&output_a, out_offset, mat, false, k_dim);
+   )MAIN_FN";
+   return Status::OK();
+ }
+
+
Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t config_index, bool has_zero_points) {
  auto& config = intel_supported_subgroup_matrix_configs[config_index];
  shader.AdditionalImplementation() << "alias component_type = " << ComponentTypeName[static_cast<uint32_t>(std::get<2>(config))] << ";\n"
                                    << "alias result_component_type = " << ComponentTypeName[static_cast<uint32_t>(std::get<3>(config))] << ";\n"
-                                   << "const m_dim = " << std::get<4>(config) << ";\n"
-                                   << "const n_dim = " << std::get<5>(config) << ";\n"
-                                   << "const k_dim = " << std::get<6>(config) << ";\n";
+                                   << "const m_dim: u32 = " << std::get<4>(config) << ";\n"
+                                   << "const n_dim: u32 = " << std::get<5>(config) << ";\n"
+                                   << "const k_dim: u32 = " << std::get<6>(config) << ";\n";

  shader.AdditionalImplementation() << R"ADDNL_FN(
-     const tile_cols = 64;
-     const tile_rows = 64;
-     const tile_k = 32;
-     const subtile_rows = 8;
-     const quantization_block_size = 32;
-
-     var<workgroup> tile_B: array<component_type, tile_cols * tile_k>;  // 64 x 32 - RxC
-   )ADDNL_FN" << GenerateZeroPointReadingCode(nbits, has_zero_points, "component_type");
+     const tile_cols: u32 = 64;
+     const tile_rows: u32 = 64;
+     const tile_k: u32 = 32;
+     const subtile_rows: u32 = 8;
+     const quantization_block_size: u32 = 32;
+
+     var<workgroup> tile_B: array<component_type, tile_cols * tile_k>;  // 64 x 32 - RxC
+   )ADDNL_FN" << GenerateZeroPointReadingCode(nbits, has_zero_points, "component_type");
  if (nbits == 4) {
    shader.AdditionalImplementation() << R"ADDNL_FN(
    fn loadSHMB(tile_base: u32, k_idx: u32, row: u32, c_idx: u32) {
@@ -153,7 +206,12 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t c
  let a_global_base = workgroup_id.y * tile_rows;
  let b_global_base = workgroup_id.x * tile_cols;

-   let subtile_id = u32(local_idx / sg_size);
+   let subgroup_id = u32(local_idx / sg_size);
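+   // A has been repacked by LayoutProgram: each m_dim x k_dim subtile is contiguous, and the
+   // K / k_dim subtiles of one row block are consecutive. Compute the flat index of this
+   // subgroup's first subtile (at k = 0) and its element offset in the repacked buffer.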
+   let a_subtile_num_per_tensor_row = u32(uniforms.K / k_dim);
+   let a_subtile_num_per_tile_col = u32(tile_rows / m_dim);
+   let a_subtile_id = (workgroup_id.y * a_subtile_num_per_tile_col + subgroup_id) * a_subtile_num_per_tensor_row;
+   let a_subtile_size = m_dim * k_dim;
+   var matrix_a_offset = a_subtile_id * a_subtile_size;

  var matC00: subgroup_matrix_result<result_component_type, n_dim, m_dim>;
  var matC01: subgroup_matrix_result<result_component_type, n_dim, m_dim>;
@@ -167,9 +225,9 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t c
  for (var step: u32 = 0; step < tile_k; step += k_dim)
  {
    // Load A from global memory.
-     let matrix_a_offset = (a_global_base + subtile_id * subtile_rows) * uniforms.K + kidx + step;
    // Syntax: subgroupMatrixLoad src_ptr, src_offset, is_col_major, src_stride
-     var matA0: subgroup_matrix_left<component_type, k_dim, m_dim> = subgroupMatrixLoad<subgroup_matrix_left<component_type, k_dim, m_dim>>(&input_a, matrix_a_offset, false, uniforms.K);
+     var matA0: subgroup_matrix_left<component_type, k_dim, m_dim> = subgroupMatrixLoad<subgroup_matrix_left<component_type, k_dim, m_dim>>(&input_a, matrix_a_offset, false, k_dim);
+     matrix_a_offset += a_subtile_size;
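+     // The repacked subtile is dense, so the source stride is k_dim rather than uniforms.K, and the
+     // subtile for the next k-step immediately follows; just advance the offset by a_subtile_size.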

    // Load B from shared local memory.
    // tile_B is stored as column major.
@@ -192,10 +250,10 @@ Status GenerateShaderCodeOnIntel(ShaderHelper& shader, uint32_t nbits, int32_t c

  // Write out
  let matrix_c_offset = (a_global_base) * uniforms.N + b_global_base;
-   subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N, matC00, false, uniforms.N);
-   subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + n_dim, matC01, false, uniforms.N);
-   subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + 2 * n_dim, matC02, false, uniforms.N);
-   subgroupMatrixStore(&output, matrix_c_offset + subtile_id * m_dim * uniforms.N + 3 * n_dim, matC03, false, uniforms.N);
+   subgroupMatrixStore(&output, matrix_c_offset + subgroup_id * m_dim * uniforms.N, matC00, false, uniforms.N);
+   subgroupMatrixStore(&output, matrix_c_offset + subgroup_id * m_dim * uniforms.N + n_dim, matC01, false, uniforms.N);
+   subgroupMatrixStore(&output, matrix_c_offset + subgroup_id * m_dim * uniforms.N + 2 * n_dim, matC02, false, uniforms.N);
+   subgroupMatrixStore(&output, matrix_c_offset + subgroup_id * m_dim * uniforms.N + 3 * n_dim, matC03, false, uniforms.N);
  )MAIN_FN";

  return Status::OK();
@@ -426,6 +484,30 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te
                                      int32_t config_index,
                                      onnxruntime::webgpu::ComputeContext& context,
                                      Tensor* y) {
+   const auto& config = intel_supported_subgroup_matrix_configs[config_index];
+   const auto component_type = ComponentTypeName[static_cast<uint32_t>(std::get<2>(config))];
+   const auto m = std::get<4>(config);
+   const auto k = std::get<6>(config);
+
+   // Optimize the layout of the input matrix A (M x K) for SubgroupMatrixLoad.
+   LayoutProgram layout_program{m, k, component_type};
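+   // One workgroup per m x k subtile; the workgroup size matches the subgroup size assumed here (32).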
+   constexpr uint32_t kSubgroupSize = 32;
+   layout_program.SetWorkgroupSize(kSubgroupSize);
+
+   const auto dispatch_group_size_x = (M + m - 1) / m;
+   ORT_ENFORCE(K % k == 0, "K must be a multiple of ", k);
+   const auto dispatch_group_size_y = K / k;
+   // Each workgroup will process one subgroup matrix of size m x k.
+   layout_program.SetDispatchGroupSize(dispatch_group_size_x, dispatch_group_size_y, 1);
+
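+   // a_layout holds A repacked by the layout pass; its row count is M rounded up to a multiple of m
+   // so that every workgroup writes a full m x k subtile.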
+   TensorShape a_layout_shape{dispatch_group_size_x * m, K};
+   Tensor a_layout = context.CreateGPUTensor(a->DataType(), a_layout_shape);
+   layout_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1}})
+       .AddOutputs({{&a_layout, ProgramTensorMetadataDependency::Rank, a_layout.Shape(), 1}})
+       .AddUniformVariables({{static_cast<uint32_t>(M)},
+                             {static_cast<uint32_t>(K)}});
+   ORT_RETURN_IF_ERROR(context.RunProgram(layout_program));
+
  uint32_t tile_size_a = 32;
  uint32_t work_group_size = 128;
  constexpr uint32_t kTileSizeB = 64;
@@ -441,7 +523,7 @@ Status ApplySubgroupMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Te
  mul_program.SetDispatchGroupSize(
      (N + kTileSizeB - 1) / kTileSizeB,
      (M + tile_size_a - 1) / tile_size_a, 1);
-   mul_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, 1},
+   mul_program.AddInputs({{&a_layout, ProgramTensorMetadataDependency::TypeAndRank, 1},
                         {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast<int>(nbits == 4 ? kU32Components : 2 * kU32Components)},
                         {scales, ProgramTensorMetadataDependency::TypeAndRank, 1}})
      .AddUniformVariables({{static_cast<uint32_t>(M)},