
Commit f190e70

add webgpu support for GatherBlockQuantized (#25413)

1 parent 1d00bff commit f190e70

5 files changed: +336 −10 lines changed
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc

Lines changed: 246 additions & 0 deletions
@@ -0,0 +1,246 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "contrib_ops/webgpu/quantization/gather_block_quantized.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;
using onnxruntime::webgpu::ComputeContext;

Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& x = shader.AddInput("input", ShaderUsage::UseElementTypeAlias);
  const auto& x_shape = shader.AddIndices("input_shape", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
  const auto& indices = shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseIndicesToOffset);
  const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);

  bool is_4bit = bits_ == 4;
  const std::string unpack = (is_signed_) ? "unpack4xI8" : "unpack4xU8";

  shader.MainFunctionBody()
      << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
      << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n";

  if (indices_rank_ > 1) {
    shader.MainFunctionBody()
        << "var indices_indices = indices_indices_t(0);\n"
        << "for (var i: u32 = 0; i < " << indices_rank_ << "; i++) {\n"
        << " let index = " << output.IndicesGet("output_indices", "uniforms.gather_axis + i") << ";\n"
        << " " << indices.IndicesSet("indices_indices", "i", "index") << ";\n};\n";
  } else {
    shader.MainFunctionBody()
        << "let indices_indices = " << output.IndicesGet("output_indices", "uniforms.gather_axis") << ";\n";
  }
  shader.MainFunctionBody()
      << "var data_indices = input_shape_indices_t(0);\n"
      << "for (var i: u32 = 0; i < uniforms.gather_axis; i++) {\n"
      << " let index = " << output.IndicesGet("output_indices", "i") << ";\n "
      << x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n"
      << "var index_from_indices = " << indices.GetByIndices("indices_indices") << ";\n"
      << "if (index_from_indices < 0) { index_from_indices += " << x_shape_[gather_axis_] << ";}\n"
      << x_shape.IndicesSet("data_indices", "uniforms.gather_axis", "u32(index_from_indices)") << ";\n"
      << "for (var i = uniforms.gather_axis + 1; i < " << output_shape_.NumDimensions() << "; i++) {\n"
      << " let index = " << output.IndicesGet("output_indices", "i + " + std::to_string(indices_rank_ - 1)) << ";\n "
      << x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n"
      << " let data_offset = " << x_shape.IndicesToOffset("data_indices") << ";\n";

  if (is_4bit) {
    shader.MainFunctionBody()
        << " let data_index = data_offset % 8;\n"
        << " let packed_4bit_quantized_data = " << x.GetByOffset("data_offset / 8") << ";\n"
        << " let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;\n"
        << " let quantized_data_vec = " << unpack << "(u32(packed_8bit_quantized_data));\n"
        << " var quantized_data = quantized_data_vec[data_index / 2];\n";
    if (is_signed_) {
      shader.MainFunctionBody()
          << " if ((quantized_data & 0x8) != 0) { quantized_data = quantized_data - 16; };\n";
    }
  } else {
    shader.MainFunctionBody()
        << " let data_index = data_offset % 4;\n"
        << " let packed_8bit_quantized_data = " << x.GetByOffset("data_offset / 4") << ";\n"
        << " let quantized_data_vec = " << unpack << "(u32(packed_8bit_quantized_data));\n"
        << " var quantized_data = quantized_data_vec[data_index];\n";
  }
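  // Worked example of the nibble addressing above (illustrative comment, not
  // part of the original commit): one u32 word packs eight 4-bit values; for
  // element k, k % 2 selects the low/high nibble of its byte and k / 2 selects
  // the byte. A CPU mirror of the signed path, with a hypothetical helper name:
  //   int8_t Extract4Bit(uint32_t word, uint32_t k) {
  //     uint32_t nibbles = (word >> (4 * (k % 2))) & 0x0f0f0f0f;  // chosen nibble of each byte
  //     int8_t v = static_cast<int8_t>((nibbles >> (8 * (k / 2))) & 0xff);  // byte k / 2
  //     return (v & 0x8) ? static_cast<int8_t>(v - 16) : v;  // sign-extend int4 (skip for unsigned)
  //   }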

  shader.MainFunctionBody()
      << " var scale_indices = data_indices;\n"
      << " let quantize_axis_index = " << scales.IndicesGet("data_indices", "uniforms.quantize_axis") << " / uniforms.block_size;\n "
      << scales.IndicesSet("scale_indices", "uniforms.quantize_axis", "quantize_axis_index") << ";\n"
      << " var scale = " << scales.GetByIndices("scale_indices") << ";\n";

  if (!has_zeropoint_) {
    const std::string default_zero_point = is_uint8_ ? (is_4bit ? "input_element_t(8)" : "input_element_t(128)") : "input_element_t(0)";
    shader.MainFunctionBody()
        << " let zero_point = " << default_zero_point << ";\n";
  } else {
    const auto& zero_point = shader.AddInput("zero_point", ShaderUsage::None);
    shader.MainFunctionBody()
        << " let zero_point_indices = scale_indices;\n"
        << " let zero_point_offset = " << scales.IndicesToOffset("zero_point_indices") << ";\n";
    if (is_4bit) {
      shader.MainFunctionBody()
          << " let zero_point_index = zero_point_offset % 8;\n"
          << " let packed_4bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 8") << ";\n"
          << " let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;\n"
          << " let zero_point_vec = " << unpack << "(u32(packed_8bit_zero_points));\n"
          << " var zero_point = zero_point_vec[zero_point_index / 2];\n";
    } else {
      shader.MainFunctionBody()
          << " let zero_point_index = zero_point_offset % 4;\n"
          << " let packed_8bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n"
          << " let zero_point_vec = " << unpack << "(u32(packed_8bit_zero_points));\n"
          << " var zero_point = zero_point_vec[zero_point_index];\n";
    }
    if (is_signed_) {
      shader.MainFunctionBody()
          << " if ((zero_point & 0x8) != 0) { zero_point = zero_point - 16; };\n";
    }
  }
  shader.MainFunctionBody()
      << " let dequantized_data = (output_value_t(quantized_data) - output_value_t(zero_point)) * scale;\n "
      << output.SetByOffset("global_idx", "dequantized_data") << ";\n";

  return Status::OK();
}
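For orientation, here is roughly the WGSL main body this generator emits for a 1-D indices tensor, 4-bit signed data, and no zero-point input. The helper names (o2i_output, i2o_input_shape, get_*_by_indices) follow the shader-variable conventions visible in this PR's shader_variable.cc change, but the exact emitted text should be treated as an approximation:

if (global_idx >= uniforms.output_size) { return; }
let output_indices = o2i_output(global_idx);
let indices_indices = output_indices[uniforms.gather_axis];
var data_indices = input_shape_indices_t(0);
for (var i: u32 = 0; i < uniforms.gather_axis; i++) {
  let index = output_indices[i];
  data_indices[i] = index;
};
var index_from_indices = indices[indices_indices];
if (index_from_indices < 0) { index_from_indices += 4; }  // x_shape_[gather_axis_], 4 chosen as an example
data_indices[uniforms.gather_axis] = u32(index_from_indices);
for (var i = uniforms.gather_axis + 1; i < 2; i++) {      // output rank baked in at codegen time
  let index = output_indices[i + 0];
  data_indices[i] = index;
};
let data_offset = i2o_input_shape(data_indices);
let data_index = data_offset % 8;
let packed_4bit_quantized_data = input[data_offset / 8];
let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;
let quantized_data_vec = unpack4xI8(u32(packed_8bit_quantized_data));
var quantized_data = quantized_data_vec[data_index / 2];
if ((quantized_data & 0x8) != 0) { quantized_data = quantized_data - 16; };
var scale_indices = data_indices;
let quantize_axis_index = data_indices[uniforms.quantize_axis] / uniforms.block_size;
scale_indices[uniforms.quantize_axis] = quantize_axis_index;
var scale = get_scales_by_indices(scale_indices);
let zero_point = input_element_t(0);                      // signed default when no zero-point tensor is bound
let dequantized_data = (output_value_t(quantized_data) - output_value_t(zero_point)) * scale;
output[global_idx] = dequantized_data;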

TensorShapeVector splice(TensorShapeVector vec, size_t start, size_t deleteCount, const TensorShapeVector toInsert = {}) {
  TensorShapeVector new_vec;

  for (size_t i = 0; i < vec.size(); i++) {
    if (i < start) {
      new_vec.push_back(vec[i]);
    } else if (i == start) {
      new_vec.insert(new_vec.end(), toInsert.begin(), toInsert.end());
    } else if (i >= start + deleteCount) {
      new_vec.push_back(vec[i]);
    }
  }
  return new_vec;
}
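The helper mirrors JavaScript's Array.prototype.splice: elements in [start, start + deleteCount) are dropped and toInsert is placed at start. A usage sketch with hypothetical shapes, matching how ComputeInternal builds the output shape below:

// data shape (4, 8), indices shape (2, 3), gather_axis 0:
TensorShapeVector data_shape{4, 8};
TensorShapeVector indices_shape{2, 3};
auto out = splice(data_shape, /*start*/ 0, /*deleteCount*/ 1, indices_shape);
// out == {2, 3, 8}: the gathered axis is replaced by the indices dimensions.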

Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const {
  const auto* x = context.Input(0);
  const auto* indices = context.Input(1);
  const auto* scales = context.Input(2);
  const auto* zero_points = context.Input(3);

  int x_rank = static_cast<int>(x->Shape().NumDimensions());
  int64_t x_dtype = x->GetElementType();
  bool is_signed = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4;
  bool is_int8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;

  std::optional<Tensor> data_representation_4bit;
  std::optional<Tensor> zero_points_representation_4bit;
  if (bits_ == 4 && is_int8) {
    TensorShape data_representation_4bit_shape{x->Shape()};
    MLDataType new_dtype = (x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) ? DataTypeImpl::GetType<UInt4x2>() : DataTypeImpl::GetType<Int4x2>();
    auto memory_info = OrtMemoryInfo{
        "WebGPU_Buffer",
        OrtDeviceAllocator,
        OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0}};

    data_representation_4bit_shape[x_rank - 1] = data_representation_4bit_shape[x_rank - 1] * 2;
    data_representation_4bit.emplace(
        new_dtype,
        data_representation_4bit_shape,
        const_cast<void*>(x->DataRaw()),
        memory_info);

    if (zero_points) {
      TensorShape zero_points_representation_4bit_shape{zero_points->Shape()};
      zero_points_representation_4bit_shape[zero_points->Shape().NumDimensions() - 1] =
          zero_points_representation_4bit_shape[zero_points->Shape().NumDimensions() - 1] * 2;
      zero_points_representation_4bit.emplace(
          new_dtype,
          zero_points_representation_4bit_shape,
          const_cast<void*>(zero_points->DataRaw()),
          memory_info);
    }
    x = data_representation_4bit.has_value() ? &data_representation_4bit.value() : x;
    zero_points = zero_points_representation_4bit.has_value() ? &zero_points_representation_4bit.value() : zero_points;
  }
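  // Illustration (assumption, not in the original commit): with bits == 4, a
  // uint8 tensor of shape (M, N) already stores two nibbles per byte, so the
  // view above reinterprets the same buffer as an (U)Int4x2 tensor of shape
  // (M, 2 * N) without copying.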

  const auto& x_shape = x->Shape();

  size_t indices_rank = indices->Shape().NumDimensions();
  const auto scales_shape = scales->Shape();
  size_t scales_rank = scales_shape.NumDimensions();
  int gather_axis = (gather_axis_ >= 0) ? gather_axis_ : gather_axis_ + x_rank;
  int quantize_axis = (quantize_axis_ >= 0) ? quantize_axis_ : quantize_axis_ + x_rank;

  ORT_RETURN_IF_NOT(x_shape.NumDimensions() == scales_rank,
                    "data and scales must have the same rank.");
  for (size_t i = 0; i < x_shape.NumDimensions(); ++i) {
    ORT_RETURN_IF_NOT(i == static_cast<size_t>(quantize_axis)
                          ? (x_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i]
                          : x_shape[i] == scales_shape[i],
                      "data and scales do not match shapes.");
  }

  TensorShape output_shape = splice(x_shape.AsShapeVector(), gather_axis, 1, indices->Shape().AsShapeVector());
  int64_t output_size = output_shape.Size();
  auto* output_tensor = context.Output(0, output_shape);

  GatherBlockQuantizedProgram program{is_signed, is_int8, indices_rank, gather_axis, bits_, zero_points != nullptr, x_shape, output_shape};

  program
      .AddInputs({{x, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}})
      .AddIndices(x_shape)
      .AddInputs({{indices, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddInputs({{scales, ProgramTensorMetadataDependency::TypeAndRank}})
      .AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      .AddUniformVariables({{static_cast<uint32_t>(output_size)}})
      .AddUniformVariables({{static_cast<uint32_t>(quantize_axis)}})
      .AddUniformVariables({{static_cast<uint32_t>(gather_axis)}})
      .AddUniformVariables({{static_cast<uint32_t>(block_size_)}})
      .CacheHint(std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_));

  if (zero_points != nullptr) {
    ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(),
                      "scales and zero_points must have the same shape.");
    program.AddInputs({{zero_points, ProgramTensorMetadataDependency::None, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}});
  }

  return context.RunProgram(program);
}
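Two constants in the program setup above encode the packing: a flattened u32 word carries 8 elements when bits == 4 and 4 elements when bits == 8, and the dispatch size is a ceiling division of the output element count by the workgroup size. A standalone sketch of that arithmetic (treating WORKGROUP_SIZE == 64 as an assumption of this sketch):

#include <cstdint>

constexpr uint32_t CeilDiv(uint32_t n, uint32_t d) { return (n + d - 1) / d; }

static_assert(CeilDiv(1000, 64) == 16, "1000 output elements -> 16 workgroups");
static_assert(CeilDiv(1000, 8) == 125, "1000 4-bit values -> 125 packed u32 words");
static_assert(CeilDiv(1000, 4) == 250, "1000 8-bit values -> 250 packed u32 words");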

namespace {
const std::vector<MLDataType>& GatherBlockQuantizedT1Constraint() {
  static std::vector<MLDataType> types{
      DataTypeImpl::GetTensorType<Int4x2>(),
      DataTypeImpl::GetTensorType<UInt4x2>(),
      DataTypeImpl::GetTensorType<uint8_t>()};
  return types;
}
const std::vector<MLDataType>& GatherBlockQuantizedTindConstraint() {
  static std::vector<MLDataType> types{
      DataTypeImpl::GetTensorType<int32_t>(),
      DataTypeImpl::GetTensorType<int64_t>()};
  return types;
}
}  // namespace

ONNX_OPERATOR_KERNEL_EX(
    GatherBlockQuantized,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T1", GatherBlockQuantizedT1Constraint())
        .TypeConstraint("T2", WebGpuSupportedFloatTypes())
        .TypeConstraint("Tind", GatherBlockQuantizedTindConstraint()),
    GatherBlockQuantized);

}  // namespace webgpu
}  // namespace contrib
}  // namespace onnxruntime
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;
using onnxruntime::webgpu::ComputeContext;

class GatherBlockQuantizedProgram final : public Program<GatherBlockQuantizedProgram> {
 public:
  GatherBlockQuantizedProgram(const bool is_signed, const bool is_uint8, size_t indices_rank, int gather_axis, int bits, bool has_zeropoint,
                              TensorShape x_shape, TensorShape output_shape) : Program<GatherBlockQuantizedProgram>{"GatherBlockQuantized"},
                                                                               is_signed_{is_signed},
                                                                               is_uint8_{is_uint8},
                                                                               indices_rank_{indices_rank},
                                                                               gather_axis_{gather_axis},
                                                                               bits_{bits},
                                                                               has_zeropoint_{has_zeropoint},
                                                                               x_shape_{x_shape},
                                                                               output_shape_{output_shape} {}

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"quantize_axis", ProgramUniformVariableDataType::Uint32},
                                          {"gather_axis", ProgramUniformVariableDataType::Uint32},
                                          {"block_size", ProgramUniformVariableDataType::Uint32});

 private:
  bool is_signed_;
  bool is_uint8_;
  size_t indices_rank_;
  int gather_axis_;
  int bits_;
  bool has_zeropoint_;
  TensorShape x_shape_;
  TensorShape output_shape_;
};

class GatherBlockQuantized final : public WebGpuKernel {
 public:
  GatherBlockQuantized(const OpKernelInfo& info) : WebGpuKernel(info) {
    gather_axis_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("gather_axis", 0));
    block_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("block_size", 128));
    quantize_axis_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("quantize_axis", 1));
    bits_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("bits", 4));

    ORT_ENFORCE(bits_ == 4 || bits_ == 8, "'bits' must be 4 or 8.");
    ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
                "'block_size' must be a power of 2 and not less than 16.");
  }

  Status ComputeInternal(ComputeContext& context) const override;

 private:
  int gather_axis_;
  int quantize_axis_;
  int block_size_;
  int bits_;
};
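The block_size check above uses the standard bit trick: a power of two has exactly one set bit, so x & (x - 1) is zero precisely for powers of two. Illustrative compile-time checks (the values are examples, not from the source):

static_assert(128 >= 16 && ((128 - 1) & 128) == 0, "128: power of two >= 16, accepted");
static_assert(!(48 >= 16 && ((48 - 1) & 48) == 0), "48: two set bits, rejected");
static_assert(!(8 >= 16 && ((8 - 1) & 8) == 0), "8: below the minimum of 16, rejected");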

}  // namespace webgpu
}  // namespace contrib
}  // namespace onnxruntime

onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Bi
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention);
 // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
@@ -40,6 +41,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
   BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
   BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
   BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
+  BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized)>,
   BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
   BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
   BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,

onnxruntime/core/providers/webgpu/shader_variable.cc

Lines changed: 2 additions & 2 deletions
@@ -268,8 +268,8 @@ void ShaderVariableHelper::Impl(std::ostream& ss) const {
   // Implementation of "fn get_{name}_by_indices"
   if (usage_ & ShaderUsage::UseGetByIndices) {
     if (rank_ >= 2) {
-      SS_APPEND(ss, "fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n");
-      SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n");
+      SS_APPEND(ss, "fn get_", name_, "_by_indices(indices_fnarg: ", IndicesType(), ")->", ValueType(), " {\n");
+      SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices_fnarg)"), ";\n");
       SS_APPEND(ss, "}\n");
     }
   }
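This rename is presumably needed by the new kernel: GatherBlockQuantized binds an input tensor literally named indices, and in the generated WGSL accessor a parameter also named indices would shadow that buffer binding inside the function body. A hedged reconstruction of the collision (identifier names illustrative):

// before: both `indices` tokens resolve to the parameter, not the buffer
fn get_indices_by_indices(indices: indices_indices_t) -> indices_value_t {
  return indices[i2o_indices(indices)];
}
// after: the buffer access resolves correctly
fn get_indices_by_indices(indices_fnarg: indices_indices_t) -> indices_value_t {
  return indices[i2o_indices(indices_fnarg)];
}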
