246 changes: 246 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
@@ -0,0 +1,246 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_utils.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "contrib_ops/webgpu/quantization/gather_block_quantized.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

[cpplint] gather_block_quantized.cc:14: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces]
using onnxruntime::webgpu::ComputeContext;

Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& x = shader.AddInput("input", ShaderUsage::UseElementTypeAlias);
const auto& x_shape = shader.AddIndices("input_shape", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& indices = shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseIndicesToOffset);
const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias);

bool is_4bit = bits_ == 4;
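// unpack4xI8 / unpack4xU8 expand one packed 32-bit word into a vec4 of 8-bit
// lanes (sign- or zero-extended); 4-bit values are widened to bytes first.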
const std::string unpack = (is_signed_) ? "unpack4xI8" : "unpack4xU8";

shader.MainFunctionBody()
<< shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
<< "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n";

if (indices_rank_ > 1) {
shader.MainFunctionBody()
<< "var indices_indices = indices_indices_t(0);\n"
<< "for (var i: u32 = 0; i < " << indices_rank_ << "; i++) {\n"
<< " let index = " << output.IndicesGet("output_indices", "uniforms.gather_axis + i") << ";\n"
<< " " << indices.IndicesSet("indices_indices", "i", "index") << ";\n};\n";
} else {
shader.MainFunctionBody()
<< "let indices_indices = " << output.IndicesGet("output_indices", "uniforms.gather_axis") << ";\n";
}
shader.MainFunctionBody()
<< "var data_indices = input_shape_indices_t(0);\n"
<< "for (var i: u32 = 0; i < uniforms.gather_axis; i++) {\n"
<< " let index = " << output.IndicesGet("output_indices", "i") << ";\n "
<< x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n"
<< "var index_from_indices = " << indices.GetByIndices("indices_indices") << ";\n"
<< "if (index_from_indices < 0) { index_from_indices += " << x_shape_[gather_axis_] << ";}\n"
<< x_shape.IndicesSet("data_indices", "uniforms.gather_axis", "u32(index_from_indices)") << ";\n"
<< "for (var i = uniforms.gather_axis + 1; i < " << output_shape_.NumDimensions() << "; i++) {\n"
<< " let index = " << output.IndicesGet("output_indices", "i + " + std::to_string(indices_rank_ - 1)) << ";\n "
<< x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n"
<< " let data_offset = " << x_shape.IndicesToOffset("data_indices") << ";\n";

if (is_4bit) {
shader.MainFunctionBody()
<< " let data_index = data_offset % 8;\n"
<< " let packed_4bit_quantized_data = " << x.GetByOffset("data_offset / 8") << ";\n"
<< " let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;\n"
<< " let quantized_data_vec = " << unpack << "(u32(packed_8bit_quantized_data));\n"
<< " var quantized_data = quantized_data_vec[data_index / 2];\n";
if (is_signed_) {
shader.MainFunctionBody()
<< " if((quantized_data & 0x8) != 0) { quantized_data = quantized_data - 16 ;};\n";
}
} else {
shader.MainFunctionBody()
<< " let data_index = data_offset % 4;\n"
<< " let packed_8bit_quantized_data = " << x.GetByOffset("data_offset / 4") << ";\n"
<< " let quantized_data_vec = " << unpack << "(u32(packed_8bit_quantized_data));\n"
<< " var quantized_data = quantized_data_vec[data_index];\n";
}

shader.MainFunctionBody()
<< " var scale_indices = data_indices;\n"
<< " let quantize_axis_index = " << scales.IndicesGet("data_indices", "uniforms.quantize_axis") << "/ uniforms.block_size;\n "
<< scales.IndicesSet("scale_indices", "uniforms.quantize_axis", "quantize_axis_index") << ";\n"
<< " var scale = " << scales.GetByIndices("scale_indices") << ";\n";

if (!has_zeropoint_) {
const std::string default_zero_point = is_uint8_ ? (is_4bit ? "input_element_t(8)" : "input_element_t(128)") : "input_element_t(0)";

shader.MainFunctionBody()
<< " let zero_point = " << default_zero_point << ";\n";
} else {
const auto& zero_point = shader.AddInput("zero_point", ShaderUsage::None);
shader.MainFunctionBody()
<< " let zero_point_indices = scale_indices;\n"
<< " let zero_point_offset = " << scales.IndicesToOffset("zero_point_indices") << ";\n";
if (is_4bit) {
shader.MainFunctionBody()
<< " let zero_point_index = zero_point_offset % 8;\n"
<< " let packed_4bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 8") << ";\n"
<< " let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;\n"
<< " let zero_point_vec = " << unpack << "(u32(packed_8bit_zero_points));\n"
<< " var zero_point = zero_point_vec[zero_point_index / 2];\n";
} else {
shader.MainFunctionBody()
<< " let zero_point_index = zero_point_offset % 4;\n"
<< " let packed_8bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n"
<< " let zero_point_vec = " << unpack << "(u32(packed_8bit_zero_points));\n"
<< " var zero_point = zero_point_vec[zero_point_index];\n";
}
if (is_signed_) {
shader.MainFunctionBody()
<< " if((zero_point & 0x8) != 0) { zero_point = zero_point - 16 ;};\n";
}
}
shader.MainFunctionBody()
<< " let dequantized_data = (output_value_t(quantized_data) - output_value_t(zero_point)) * scale;\n "
<< output.SetByOffset("global_idx", "dequantized_data") << ";\n";

return Status::OK();
}
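
For reference, the dequantization the generated WGSL performs is (q - zero_point) * scale per element. A minimal CPU sketch, assuming the signed 4-bit layout described above; the function name and scalar form are illustrative, not part of this PR:

#include <cstddef>
#include <cstdint>

// Hedged reference (assumptions: signed 4-bit values packed two per byte, low
// nibble first; one scale per block_size consecutive elements; zero_point
// defaults to 0 for signed data, matching the shader's default).
float DequantizeInt4(const std::uint8_t* data, const float* scales,
                     std::size_t element_index, std::size_t block_size,
                     int zero_point = 0) {
  const std::uint8_t b = data[element_index / 2];
  int q = (b >> (4 * (element_index % 2))) & 0x0F;  // extract the nibble
  if (q & 0x8) {
    q -= 16;  // sign-extend from 4 bits to [-8, 7]
  }
  const float scale = scales[element_index / block_size];
  return static_cast<float>(q - zero_point) * scale;  // (q - zp) * scale
}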

// JS-style splice: returns a copy of vec with delete_count elements removed at
// start and to_insert inserted in their place.
// (If start == vec.size(), nothing is inserted; callers pass start < vec.size().)
TensorShapeVector splice(const TensorShapeVector& vec, size_t start, size_t delete_count,
                         const TensorShapeVector& to_insert = {}) {
  TensorShapeVector new_vec;
  for (size_t i = 0; i < vec.size(); i++) {
    if (i == start) {
      new_vec.insert(new_vec.end(), to_insert.begin(), to_insert.end());
    }
    if (i < start || i >= start + delete_count) {
      new_vec.push_back(vec[i]);
    }
  }
  return new_vec;
}
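
The splice helper above gives the Gather output-shape rule: data.shape[:gather_axis] + indices.shape + data.shape[gather_axis + 1:]. A worked example with hypothetical shapes:

// Illustrative only (shapes hypothetical, not from the PR):
//   data shape {2, 3, 4}, gather_axis = 1, indices shape {5, 6}
//   splice({2, 3, 4}, 1, 1, {5, 6}) == {2, 5, 6, 4}   // the output shape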

Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const {
const auto* x = context.Input(0);
const auto* indices = context.Input(1);
const auto* scales = context.Input(2);
const auto* zero_points = context.Input(3);

int x_rank = static_cast<int>(x->Shape().NumDimensions());
int64_t x_dtype = x->GetElementType();
bool is_signed = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4;
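// Note: the registered T1 constraint below admits Int4x2, UInt4x2, and uint8_t
// only, so 8-bit data is always unsigned here and this flag doubles as the
// program's is_uint8 parameter.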
bool is_int8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;

std::optional<Tensor> data_representation_4bit;
std::optional<Tensor> zero_points_representation_4bit;
if (bits_ == 4 && is_int8) {
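// 4-bit data may arrive packed in a uint8 tensor; reinterpret the same GPU
// buffer as (U)Int4x2 so each byte contributes two elements (the innermost
// dimension doubles). No copy is made: the Tensor view wraps x->DataRaw().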
TensorShape data_representation_4bit_shape{x->Shape()};
MLDataType new_dtype = (x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) ? DataTypeImpl::GetType<UInt4x2>() : DataTypeImpl::GetType<Int4x2>();
auto memory_info = OrtMemoryInfo{
"WebGPU_Buffer",
OrtDeviceAllocator,
OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0}};

data_representation_4bit_shape[x_rank - 1] = data_representation_4bit_shape[x_rank - 1] * 2;
data_representation_4bit.emplace(
new_dtype,
data_representation_4bit_shape,
const_cast<void*>(x->DataRaw()),
memory_info);

if (zero_points) {
TensorShape zero_points_representation_4bit_shape{zero_points->Shape()};
zero_points_representation_4bit_shape[zero_points->Shape().NumDimensions() - 1] =
zero_points_representation_4bit_shape[zero_points->Shape().NumDimensions() - 1] * 2;
zero_points_representation_4bit.emplace(
new_dtype,
zero_points_representation_4bit_shape,
const_cast<void*>(zero_points->DataRaw()),
memory_info);
}
x = data_representation_4bit.has_value() ? &data_representation_4bit.value() : x;
zero_points = zero_points_representation_4bit.has_value() ? &zero_points_representation_4bit.value() : zero_points;
}

const auto& x_shape = x->Shape();

size_t indices_rank = indices->Shape().NumDimensions();
const auto scales_shape = scales->Shape();
size_t scales_rank = scales_shape.NumDimensions();
int gather_axis = (gather_axis_ >= 0) ? gather_axis_ : gather_axis_ + x_rank;
int quantize_axis = (quantize_axis_ >= 0) ? quantize_axis_ : quantize_axis_ + x_rank;
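// scales must match data's shape except along quantize_axis, where the dim is
// ceil(data_dim / block_size). Illustrative example: data {4, 64}, block_size
// 16, quantize_axis 1 -> scales {4, 4}.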

ORT_RETURN_IF_NOT(x_shape.NumDimensions() == scales_rank,
"data and scales must have the same rank.");
for (size_t i = 0; i < x_shape.NumDimensions(); ++i) {
  ORT_RETURN_IF_NOT(i == static_cast<size_t>(quantize_axis)
                        ? (x_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i]
                        : x_shape[i] == scales_shape[i],
                    "data and scales shapes do not match.");
}

TensorShape output_shape = splice(x_shape.AsShapeVector(), gather_axis, 1, indices->Shape().AsShapeVector());
int64_t output_size = output_shape.Size();
auto* output_tensor = context.Output(0, output_shape);

GatherBlockQuantizedProgram program{is_signed, is_int8, indices_rank, gather_axis, bits_, zero_points != nullptr, x_shape, output_shape};

program
.AddInputs({{x, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}})
.AddIndices(x_shape)
.AddInputs({{indices, ProgramTensorMetadataDependency::TypeAndRank}})
.AddInputs({{scales, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutput({output_tensor, ProgramTensorMetadataDependency::None})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({{static_cast<uint32_t>(output_size)}})
.AddUniformVariables({{static_cast<uint32_t>(quantize_axis)}})
.AddUniformVariables({{static_cast<uint32_t>(gather_axis)}})
.AddUniformVariables({{static_cast<uint32_t>(block_size_)}})
.CacheHint(std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_));

if (zero_points != nullptr) {
  ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(),
                    "scales and zero_points must have the same shape.");
  program.AddInputs({{zero_points, ProgramTensorMetadataDependency::None, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}});
}

return context.RunProgram(program);
}

namespace {
const std::vector<MLDataType>& GatherBlockQuantizedT1Constraint() {
static std::vector<MLDataType> types{
DataTypeImpl::GetTensorType<Int4x2>(),
DataTypeImpl::GetTensorType<UInt4x2>(),
DataTypeImpl::GetTensorType<uint8_t>()};
return types;
}
const std::vector<MLDataType>& GatherBlockQuantizedTindConstraint() {
static std::vector<MLDataType> types{

Check warning on line 226 in onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc:226: Add #include <vector> for vector<> [build/include_what_you_use] [4]
DataTypeImpl::GetTensorType<int32_t>(),
DataTypeImpl::GetTensorType<int64_t>()};
return types;
}
} // namespace

ONNX_OPERATOR_KERNEL_EX(
GatherBlockQuantized,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T1", GatherBlockQuantizedT1Constraint())
.TypeConstraint("T2", WebGpuSupportedFloatTypes())
.TypeConstraint("Tind", GatherBlockQuantizedTindConstraint()),
GatherBlockQuantized);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
71 changes: 71 additions & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h
@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

[cpplint] gather_block_quantized.h:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces]
using onnxruntime::webgpu::ComputeContext;

class GatherBlockQuantizedProgram final : public Program<GatherBlockQuantizedProgram> {
public:
GatherBlockQuantizedProgram(const bool is_signed, const bool is_uint8, size_t indices_rank, int gather_axis, int bits, bool has_zeropoint,
TensorShape x_shape, TensorShape output_shape) : Program<GatherBlockQuantizedProgram>{"GatherBlockQuantized"},
is_signed_{is_signed},
is_uint8_{is_uint8},
indices_rank_{indices_rank},
gather_axis_{gather_axis},
bits_{bits},
has_zeropoint_{has_zeropoint},
x_shape_{x_shape},
output_shape_{output_shape} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
{"quantize_axis", ProgramUniformVariableDataType::Uint32},
{"gather_axis", ProgramUniformVariableDataType::Uint32},
{"block_size", ProgramUniformVariableDataType::Uint32});

private:
bool is_signed_;
bool is_uint8_;
size_t indices_rank_;
int gather_axis_;
int bits_;
bool has_zeropoint_;
TensorShape x_shape_;
TensorShape output_shape_;
};

class GatherBlockQuantized final : public WebGpuKernel {
public:
GatherBlockQuantized(const OpKernelInfo& info) : WebGpuKernel(info) {
gather_axis_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("gather_axis", 0));
block_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("block_size", 128));
quantize_axis_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("quantize_axis", 1));
bits_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("bits", 4));

ORT_ENFORCE(bits_ == 4 || bits_ == 8, "'bits' must be 4 or 8.");
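// Power-of-two test: n & (n - 1) clears the lowest set bit, so the result is
// zero exactly when block_size_ has a single bit set.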
ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0,
            "'block_size' must be a power of two and not less than 16.");
}

Status ComputeInternal(ComputeContext& context) const override;

private:
int gather_axis_;
int quantize_axis_;
int block_size_;
int bits_;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
@@ -14,6 +14,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Bi
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention);
// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
@@ -40,6 +41,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/webgpu/shader_variable.cc
@@ -268,8 +268,8 @@ void ShaderVariableHelper::Impl(std::ostream& ss) const {
// Implementation of "fn get_{name}_by_indices"
if (usage_ & ShaderUsage::UseGetByIndices) {
if (rank_ >= 2) {
SS_APPEND(ss, "fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n");
SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n");
SS_APPEND(ss, "fn get_", name_, "_by_indices(indices_fnarg: ", IndicesType(), ")->", ValueType(), " {\n");
SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices_fnarg)"), ";\n");
SS_APPEND(ss, "}\n");
}
}