Skip to content

Commit c52136f

Browse files
authored
Create ReLU6 Vulkan Compute Shader & Add Op To Backend Delegate (#12558)
Summary: This diff adds support for the ReLU6 activation function in the Vulkan backend of Executorch. It introduces a Vulkan compute shader for ReLU6 and integrates the new op into the Vulkan delegate for execution. This enables models using ReLU6 to run efficiently on Vulkan targets, improving compatibility and expanding the delegate’s supported op set. The shader supports both buffer and texture data formats, enabling flexible deployment across different memory storage modes and ensuring compatibility with a wide range of models and execution environments. Differential Revision: D78124607 cc @SS-JIA @manuelcandales @cbilgin
1 parent 6aa864e commit c52136f

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,11 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
104104
kClampShaderName); \
105105
}
106106

107+
#define DEFINE_RELU6_FN(op_name) \
108+
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
109+
return add_unary_op_node(graph, args[0], 0, 6, args[1], kClampShaderName); \
110+
}
111+
107112
#define DEFINE_HARDSHRINK_FN(op_name) \
108113
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
109114
return add_unary_op_node( \
@@ -146,6 +151,7 @@ DEFINE_ACTIVATION_FN(tanh);
146151
DEFINE_CLAMP_FN(clamp);
147152
DEFINE_CLAMP_FN(hardtanh);
148153
DEFINE_RELU_FN(relu);
154+
DEFINE_RELU6_FN(relu6);
149155
DEFINE_HARDSHRINK_FN(hardshrink);
150156
DEFINE_ACTIVATION_FN(hardswish);
151157
DEFINE_ACTIVATION_FN(hardsigmoid);
@@ -161,6 +167,7 @@ REGISTER_OPERATORS {
161167
VK_REGISTER_OP(aten.hardtanh.default, hardtanh);
162168
VK_REGISTER_OP(aten.neg.default, neg);
163169
VK_REGISTER_OP(aten.relu.default, relu);
170+
VK_REGISTER_OP(aten.relu6.default, relu6);
164171
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
165172
VK_REGISTER_OP(aten.sin.default, sin);
166173
VK_REGISTER_OP(aten.sqrt.default, sqrt);

backends/vulkan/test/op_tests/cases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,7 @@ def get_var_inputs():
15261526
"aten.leaky_relu.default",
15271527
"aten.round.default",
15281528
"aten.tan.default",
1529+
"aten.relu6.default",
15291530
]
15301531
)
15311532
def get_unary_ops_inputs():

0 commit comments

Comments
 (0)