
Commit 05b7b9f

Authored by ethansfng (Ethan Ng) and manuelcandales
Add support for strongly typed op_quantized_relu (#13345)
Differential Revision: D80117641

Co-authored-by: Ethan Ng <[email protected]>
Co-authored-by: Manuel Candales <[email protected]>
1 parent: d262061 · commit: 05b7b9f

9 files changed: +314 −17 lines

backends/cadence/aot/functions.yaml

Lines changed: 10 additions & 0 deletions

```diff
@@ -219,6 +219,16 @@
     - arg_meta: null
       kernel_name: impl::reference::quantized_relu_per_tensor_out
 
+- func: cadence::quantized_relu_asym8s_asym8s.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_relu_asym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_relu_asym8u_asym8u.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_relu_asym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
```
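Both typed overloads keep the schema of the generic `quantized_relu.per_tensor_out`: an input tensor plus four scalars (`X_zero_point`, `out_zero_point`, `out_multiplier`, `out_shift`). For orientation, here is a hand-written Python sketch of what a per-tensor quantized ReLU with these parameters computes; the exact fixed-point rounding and shift convention of the Cadence kernels is an assumption (it is not visible in this diff), so treat the requantization step as illustrative only.

```python
import torch

def quantized_relu_per_tensor_sketch(
    x: torch.Tensor,      # int8 or uint8 input
    in_zero_point: int,
    out_zero_point: int,
    out_multiplier: int,  # assumed Q31 fixed-point multiplier
    out_shift: int,       # assumed left-shift amount
) -> torch.Tensor:
    # ReLU in the quantized domain: everything at or below the input
    # zero point maps to zero; keep only the positive part.
    acc = (x.to(torch.int64) - in_zero_point).clamp(min=0)
    # Requantize with the fixed-point multiplier; the Q31 encoding and
    # shift direction here are assumptions, not taken from the diff.
    y = (acc * out_multiplier) >> (31 - out_shift)
    # Re-center on the output zero point and saturate to the dtype range.
    info = torch.iinfo(x.dtype)
    return (y + out_zero_point).clamp(info.min, info.max).to(x.dtype)
```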

backends/cadence/aot/functions_hifi.yaml

Lines changed: 10 additions & 0 deletions

```diff
@@ -339,6 +339,16 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_relu_per_tensor_out
 
+- func: cadence::quantized_relu_asym8s_asym8s.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_relu_asym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_relu_asym8u_asym8u.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_relu_asym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null
```

backends/cadence/aot/ops_registrations.py

Lines changed: 36 additions & 0 deletions

```diff
@@ -232,6 +232,20 @@
     "quantized_relu.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, "
     "int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_relu_asym8s_asym8s.per_tensor(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift) -> Tensor"
+)
+lib.define(
+    "quantized_relu_asym8s_asym8s.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, "
+    "int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_relu_asym8u_asym8u.per_tensor(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, int out_shift) -> Tensor"
+)
+lib.define(
+    "quantized_relu_asym8u_asym8u.per_tensor_out(Tensor X, int X_zero_point, int out_zero_point, int out_multiplier, "
+    "int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_add.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
@@ -770,6 +784,28 @@ def quantized_relu_per_tensor_meta(
     return input.new_empty(input.size(), dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_relu_asym8s_asym8s.per_tensor")
+def quantized_relu_asym8s_asym8s_per_tensor_meta(
+    input: torch.Tensor,
+    in_zero_point: int,
+    out_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
+@register_fake("cadence::quantized_relu_asym8u_asym8u.per_tensor")
+def quantized_relu_asym8u_asym8u_per_tensor_meta(
+    input: torch.Tensor,
+    in_zero_point: int,
+    out_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+) -> torch.Tensor:
+    return input.new_empty(input.size(), dtype=input.dtype)
+
+
 @register_fake("cadence::fully_connected")
 def fully_connected_meta(
     src: torch.Tensor,
```
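The `@register_fake` meta kernels give export-time tracing shape and dtype information without running a real kernel: since ReLU is shape-preserving, each fake simply returns an empty tensor matching the input. A minimal, self-contained sketch of the same pattern, using a hypothetical `my_ops` namespace and op name that are not part of this diff:

```python
import torch
from torch.library import Library, register_fake

# Hypothetical library, defined only to illustrate the pattern above.
lib = Library("my_ops", "DEF")
lib.define("relu_like(Tensor x, int zero_point) -> Tensor")

@register_fake("my_ops::relu_like")
def relu_like_meta(x: torch.Tensor, zero_point: int) -> torch.Tensor:
    # Shape/dtype propagation only; no real computation happens here.
    return x.new_empty(x.size(), dtype=x.dtype)

# Under FakeTensorMode (as used during export), calling the op runs the
# fake kernel and yields a correctly shaped fake tensor.
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
    fx = torch.empty(2, 3, dtype=torch.int8)
    out = torch.ops.my_ops.relu_like(fx, 0)
    print(out.shape, out.dtype)  # torch.Size([2, 3]) torch.int8
```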

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 48 additions & 0 deletions

```diff
@@ -137,3 +137,51 @@ def test_mixed_types_error(self) -> None:
         with self.assertRaises(RuntimeError) as context:
             cast(PassResult, p(gm)).graph_module
         self.assertIn("Unsupported input types", str(context.exception))
+
+    def test_int8_dispatch_quantized_relu(self) -> None:
+        """Test int8 input should dispatch to asym8s_asym8s variant for quantized_relu"""
+        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.cadence.quantized_relu.per_tensor,
+            args=(x, 0, 0, 1, 0),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_relu.per_tensor),
+            0,
+        )
+        # Should be replaced with int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_relu_asym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_relu(self) -> None:
+        """Test uint8 input should dispatch to asym8u_asym8u variant for quantized_relu"""
+        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+        gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.cadence.quantized_relu.per_tensor,
+            args=(x, 0, 0, 1, 0),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_relu.per_tensor),
+            0,
+        )
+        # Should be replaced with uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_relu_asym8u_asym8u.per_tensor,
+            ),
+            1,
+        )
```

backends/cadence/aot/type_dispatch.py

Lines changed: 40 additions & 17 deletions

```diff
@@ -23,40 +23,63 @@ class CompileTimeTypeDispatchPass(ExportPass):
     Replaces generic ops with ops that have explicit types.
     """
 
-    _TYPE_DISPATCH_MAP: dict[tuple[torch.dtype, torch.dtype], str] = {
+    _BINARY_TYPE_DISPATCH_MAP: dict[tuple[torch.dtype, torch.dtype], str] = {
         (torch.int8, torch.int8): "asym8sxasym8s_asym8s",
         (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
     }
 
-    _SUPPORTED_OPS: dict[OpOverload, str] = {
+    _UNARY_TYPE_DISPATCH_MAP: dict[torch.dtype, str] = {
+        torch.int8: "asym8s_asym8s",
+        torch.uint8: "asym8u_asym8u",
+    }
+
+    _BINARY_SUPPORTED_OPS: dict[OpOverload, str] = {
         exir_ops.edge.cadence.quantized_fully_connected.per_tensor: "quantized_fully_connected",
         exir_ops.edge.cadence.quantized_linear.per_tensor: "quantized_linear",
     }
 
+    _SUPPORTED_UNARY_OPS: dict[OpOverload, str] = {
+        exir_ops.edge.cadence.quantized_relu.per_tensor: "quantized_relu",
+    }
+
     def call_operator(
         self,
         op: OpOverload,
         args: tuple[Argument, ...],
         kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        if op not in self._SUPPORTED_OPS:
-            return super().call_operator(op, args, kwargs, meta)
+        if op in self._BINARY_SUPPORTED_OPS:
+            # pyre-ignore[16]: None has no attribute `to_tensor`.
+            input_dtype = args[0].to_tensor().dtype
+            weight_dtype = args[1].to_tensor().dtype
+            dtype_pair = (input_dtype, weight_dtype)
+
+            if dtype_pair not in self._BINARY_TYPE_DISPATCH_MAP:
+                raise RuntimeError(
+                    f"Unsupported input types for {op}: {input_dtype} and {weight_dtype}"
+                )
+
+            base_op_name = self._BINARY_SUPPORTED_OPS[op]
+            type_suffix = self._BINARY_TYPE_DISPATCH_MAP[dtype_pair]
+
+            typed_op_name = f"{base_op_name}_{type_suffix}"
+            typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
+
+            return super().call_operator(typed_op, args, kwargs, meta)
+
+        elif op in self._SUPPORTED_UNARY_OPS:
+            input_dtype = args[0].to_tensor().dtype
 
-        # pyre-ignore[16]: None has no attribute `to_tensor`.
-        input_dtype = args[0].to_tensor().dtype
-        weight_dtype = args[1].to_tensor().dtype
-        dtype_pair = (input_dtype, weight_dtype)
+            if input_dtype not in self._UNARY_TYPE_DISPATCH_MAP:
+                raise RuntimeError(f"Unsupported input type for {op}: {input_dtype}")
 
-        if dtype_pair not in self._TYPE_DISPATCH_MAP:
-            raise RuntimeError(
-                f"Unsupported input types for {op}: {input_dtype} and {weight_dtype}"
-            )
+            base_op_name = self._SUPPORTED_UNARY_OPS[op]
+            type_suffix = self._UNARY_TYPE_DISPATCH_MAP[input_dtype]
 
-        base_op_name = self._SUPPORTED_OPS[op]
-        type_suffix = self._TYPE_DISPATCH_MAP[dtype_pair]
+            typed_op_name = f"{base_op_name}_{type_suffix}"
+            typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
 
-        typed_op_name = f"{base_op_name}_{type_suffix}"
-        typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
+            return super().call_operator(typed_op, args, kwargs, meta)
 
-        return super().call_operator(typed_op, args, kwargs, meta)
+        return super().call_operator(op, args, kwargs, meta)
```
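Stripped of the `ExportPass` plumbing, the new unary dispatch path reduces to a dtype-keyed name lookup. A self-contained sketch (plain strings in place of `OpOverload`s, so it runs without ExecuTorch installed):

```python
import torch

# Mirrors _UNARY_TYPE_DISPATCH_MAP in the pass above.
_UNARY_TYPE_DISPATCH_MAP = {
    torch.int8: "asym8s_asym8s",
    torch.uint8: "asym8u_asym8u",
}

def dispatch_unary(base_op_name: str, input_dtype: torch.dtype) -> str:
    """Map a generic op name plus an input dtype to its typed variant name."""
    if input_dtype not in _UNARY_TYPE_DISPATCH_MAP:
        raise RuntimeError(f"Unsupported input type for {base_op_name}: {input_dtype}")
    return f"{base_op_name}_{_UNARY_TYPE_DISPATCH_MAP[input_dtype]}"

assert dispatch_unary("quantized_relu", torch.int8) == "quantized_relu_asym8s_asym8s"
assert dispatch_unary("quantized_relu", torch.uint8) == "quantized_relu_asym8u_asym8u"
```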
New file: HiFi kernel `quantized_relu_asym8s_asym8s_per_tensor_out` (file path header not shown on this page)

Lines changed: 52 additions & 0 deletions

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <xa_nnlib_kernels_api.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

void quantized_relu_asym8s_asym8s_per_tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const int64_t in_zero_point,
    const int64_t out_zero_point,
    const int64_t out_multiplier,
    const int64_t out_shift,
    Tensor& output) {
  const int8_t* __restrict__ input_data = input.const_data_ptr<int8_t>();
  int8_t* __restrict__ output_data = output.mutable_data_ptr<int8_t>();

  const int32_t out_multiplier_int32 = static_cast<int32_t>(out_multiplier);
  const int32_t out_shift_int32 = static_cast<int32_t>(out_shift);

  const int32_t ret = xa_nn_vec_relu_asym8s_asym8s(
      output_data,
      input_data,
      in_zero_point,
      out_multiplier_int32,
      out_shift_int32,
      out_zero_point,
      -128,
      127,
      input.numel());
  ET_DCHECK_MSG(
      ret == 0, "HiFi quantized_relu_asym8s_asym8s_per_tensor failed");
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
```
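The uint8 kernel below is identical apart from the element type and the saturation bounds passed to the NNLib call, which move from the int8 range [-128, 127] to the uint8 range [0, 255]. A quick check that these hard-coded bounds match the dtype limits:

```python
import torch

# The clamp bounds hard-coded in the two kernels are just the dtype ranges.
for dtype in (torch.int8, torch.uint8):
    info = torch.iinfo(dtype)
    print(f"{dtype}: [{info.min}, {info.max}]")
# torch.int8: [-128, 127]
# torch.uint8: [0, 255]
```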
New file: HiFi kernel `quantized_relu_asym8u_asym8u_per_tensor_out` (file path header not shown on this page)

Lines changed: 52 additions & 0 deletions

```cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <xa_nnlib_kernels_api.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

void quantized_relu_asym8u_asym8u_per_tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& input,
    const int64_t in_zero_point,
    const int64_t out_zero_point,
    const int64_t out_multiplier,
    const int64_t out_shift,
    Tensor& output) {
  const uint8_t* __restrict__ input_data = input.const_data_ptr<uint8_t>();
  uint8_t* __restrict__ output_data = output.mutable_data_ptr<uint8_t>();

  const int32_t out_multiplier_int32 = static_cast<int32_t>(out_multiplier);
  const int32_t out_shift_int32 = static_cast<int32_t>(out_shift);

  const int32_t ret = xa_nn_vec_relu_asym8u_asym8u(
      output_data,
      input_data,
      in_zero_point,
      out_multiplier_int32,
      out_shift_int32,
      out_zero_point,
      0,
      255,
      input.numel());
  ET_DCHECK_MSG(
      ret == 0, "HiFi quantized_relu_asym8u_asym8u_per_tensor failed");
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
```

backends/cadence/hifi/operators/targets.bzl

Lines changed: 2 additions & 0 deletions

```diff
@@ -73,6 +73,8 @@ OPERATORS = [
     "quantized_linear_asym8uxasym8u_asym8u_per_tensor_out",
     "quantized_matmul_out",
     "quantized_relu_out",
+    "quantized_relu_asym8s_asym8s_per_tensor_out",
+    "quantized_relu_asym8u_asym8u_per_tensor_out",
     "quantize_per_tensor",
     "remainder",
     "rsqrt",
```
