-
Notifications
You must be signed in to change notification settings - Fork 3.5k
add webgpu support for GatherBlockQuantized #25413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds WebGPU support for the GatherBlockQuantized operation by implementing a complete WebGPU kernel. The implementation includes shader code generation, tensor handling for quantized data, and proper integration with the WebGPU execution provider.
- Implements WebGPU kernel for GatherBlockQuantized operation with support for 4-bit and 8-bit quantized data
- Adds comprehensive test coverage with WebGPU-specific test execution paths
- Fixes shader variable naming conflict to prevent compilation issues
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 1 comment.
Show a summary per file
File | Description |
---|---|
gather_block_quantized_op_test.cc | Updates test framework to support WebGPU execution and adds device data testing |
shader_variable.cc | Fixes parameter naming conflict in shader function generation |
webgpu_contrib_kernels.cc | Registers the new GatherBlockQuantized kernel with WebGPU provider |
gather_block_quantized.h | Defines the WebGPU kernel class and shader program interface |
gather_block_quantized.cc | Implements the complete WebGPU kernel with shader generation and computation logic |
Comments suppressed due to low confidence (2)
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc:114
- [nitpick] The function name 'splice' is ambiguous and doesn't clearly indicate its purpose of modifying tensor shape vectors. Consider renaming to 'ModifyTensorShape' or 'InsertTensorDimensions' to better reflect its functionality.
TensorShapeVector splice(TensorShapeVector vec, size_t start, size_t deleteCount, const TensorShapeVector toInsert = {}) {
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc:140
- The variable name 'is_int8' is misleading since it returns true for both INT8 and UINT8 types. Consider renaming to 'is_8bit' or 'is_byte_type' to accurately reflect that it checks for 8-bit data types regardless of signedness.
bool is_int8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
Hi there! We haven't cut the release branch for this version yet, so I'm removing the |
add webgpu support for GatherBlockQuantized
add webgpu support for GatherBlockQuantized
add webgpu support for GatherBlockQuantized