
Commit eb253fd

Author: sidart (committed)
Initial draft CMSIS-NN integration (WIP)
1 parent f6cc262 commit eb253fd

7 files changed: 209 additions, 12 deletions

backends/cortex_m/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
@@ -25,9 +25,11 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
 
 # Cortex-M ops kernel sources
 set(_cortex_m_kernels__srcs
+  ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_add.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_aten_add_tensor.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_softmax.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_add.cpp
 )
 
 # Generate C++ bindings to register kernels into Executorch (for runtime).

backends/cortex_m/ops/op_add.cpp

Lines changed: 6 additions & 2 deletions
@@ -1,5 +1,6 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <cinttypes>
+#include <iostream>
+
 namespace cortex_m {
 namespace native {
 
@@ -13,7 +14,9 @@ Tensor& add_out(
     const Tensor& input2,
     const ScalarType dtype,
     Tensor& out) {
-
+  std::cout << "add_out kernel called" << std::endl;
+  ET_LOG(Info, "xxxxxxxxxx add_out kernel called");
+
   // Ensure input is char type
   ET_CHECK_MSG(
       input1.scalar_type() == ScalarType::Char,
@@ -37,6 +40,7 @@ Tensor& add_out(
       "dtype %" PRId8 " is not int8 (Char)",
       static_cast<int8_t>(dtype));
 
+  assert(false);
 
   return out;
 }
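
The add_out kernel above still ends in assert(false). Since the commit's stated goal is CMSIS-NN integration, the call it is presumably building toward is CMSIS-NN's arm_elementwise_add_s8. The fragment below is only a rough sketch of that routing: the helper name, the include, and the requantization parameters (offsets, multipliers, shifts) are placeholders I am assuming, and the exact signature should be confirmed against arm_nnfunctions.h in whatever CMSIS-NN revision the backend pins.

// Hypothetical helper, not part of this commit.
#include <cstdint>
#include "arm_nnfunctions.h" // CMSIS-NN; how it joins the build is still TBD

// Sketch: route a quantized int8 element-wise add through CMSIS-NN. All
// requantization parameters here are placeholders; a real kernel would derive
// them from the scales/zero-points of the two inputs and the output tensor.
static void add_s8_via_cmsis_nn(
    const int8_t* input1,
    const int8_t* input2,
    int8_t* output,
    int32_t block_size) {
  const int32_t input1_offset = 0, input1_mult = 1 << 30, input1_shift = 0;
  const int32_t input2_offset = 0, input2_mult = 1 << 30, input2_shift = 0;
  const int32_t out_offset = 0, out_mult = 1 << 30, out_shift = 0;
  const int32_t left_shift = 20;

  arm_elementwise_add_s8(
      input1, input2,
      input1_offset, input1_mult, input1_shift,
      input2_offset, input2_mult, input2_shift,
      left_shift,
      output,
      out_offset, out_mult, out_shift,
      /*out_activation_min=*/-128,
      /*out_activation_max=*/127,
      block_size);
}
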
backends/cortex_m/ops/op_aten_add_tensor.cpp

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <iostream>
+
+namespace cortex_m {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+Tensor& aten_add_tensor(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    const Tensor& other,
+    const ScalarType dtype,
+    Tensor& out) {
+  ET_LOG(Info, "xxxxxxxxxx aten_add_tensor kernel called");
+
+  // Ensure input is char type
+  ET_CHECK_MSG(
+      self.scalar_type() == ScalarType::Char,
+      "self.scalar_type() %" PRId8 " is not char type",
+      static_cast<int8_t>(self.scalar_type()));
+
+  ET_CHECK_MSG(
+      other.scalar_type() == ScalarType::Char,
+      "other.scalar_type() %" PRId8 " is not char type",
+      static_cast<int8_t>(other.scalar_type()));
+
+  // Check dtype is int8 (Char)
+  ET_CHECK_MSG(
+      dtype == ScalarType::Char,
+      "dtype %" PRId8 " is not int8 (Char)",
+      static_cast<int8_t>(dtype));
+
+  // Example: element-wise add self and other into out
+  // (Assuming Tensor has data() and size() methods)
+  const int8_t* self_data = self.const_data_ptr<int8_t>();
+  const int8_t* other_data = other.const_data_ptr<int8_t>();
+  int8_t* out_data = out.mutable_data_ptr<int8_t>();
+  size_t numel = self.numel(); // or self.size() if that's the API
+  for (size_t i = 0; i < numel; ++i) {
+    out_data[i] = self_data[i] + other_data[i];
+  }
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m
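
The loop in this new kernel adds raw int8 values, so any sum outside [-128, 127] wraps around, and no requantization against the tensors' scales/zero-points is applied yet. A minimal, portable stand-in (not CMSIS-NN, and assuming the same raw-pointer accessors used above) that at least saturates the result could look like this sketch:

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Saturating int8 element-wise add: accumulate in a wider type, then clamp
// to the int8 range instead of letting the sum wrap around.
static void add_s8_saturating(
    const int8_t* a, const int8_t* b, int8_t* out, size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    const int32_t sum =
        static_cast<int32_t>(a[i]) + static_cast<int32_t>(b[i]);
    out[i] = static_cast<int8_t>(
        std::min<int32_t>(127, std::max<int32_t>(-128, sum)));
  }
}

A proper quantized add also needs the rescale step that CMSIS-NN's element-wise kernels perform; saturation alone only removes the wrap-around.
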

backends/cortex_m/ops/op_softmax.cpp

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <iostream>
+
+namespace cortex_m {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+Tensor& softmax_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    int64_t dim,
+    bool half_to_float,
+    Tensor& out) {
+  // Your optimized implementation here
+  // Fill 'out' with the result and return it
+  std::cout << "xxxxxxxxxx softmax_out kernel called" << std::endl;
+  std::cout.flush();
+  ET_LOG(Error, "xxxxxxxxxx softmax_out kernel called");
+
+  return out;
+}
+
+Tensor softmax(
+    KernelRuntimeContext& ctx,
+    const Tensor& self,
+    int64_t dim,
+    bool half_to_float) {
+  std::cout << "xxxxxxxxxx softmax_default kernel called" << std::endl;
+  std::cout.flush();
+  ET_LOG(Error, "xxxxxxxxxx softmax_default kernel called");
+  return self;
+}
+
+} // namespace native
+} // namespace cortex_m
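
Both softmax entry points are still stubs that only log. Until a CMSIS-NN kernel (presumably something like its int8 softmax, arm_softmax_s8) is wired in, a numerically stable reference softmax over one contiguous row is the usual fallback; the sketch below is an assumption about how such a fallback could look, not code from this commit.

#include <cmath>
#include <cstddef>

// Reference softmax over one contiguous row: subtract the row max for
// numerical stability, exponentiate, then normalize by the running sum.
static void softmax_row_f32(const float* in, float* out, size_t n) {
  if (n == 0) {
    return;
  }
  float max_val = in[0];
  for (size_t i = 1; i < n; ++i) {
    max_val = in[i] > max_val ? in[i] : max_val;
  }
  float sum = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    out[i] = std::exp(in[i] - max_val);
    sum += out[i];
  }
  for (size_t i = 0; i < n; ++i) {
    out[i] /= sum;
  }
}
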

backends/cortex_m/ops/operators.py

Lines changed: 70 additions & 2 deletions
@@ -13,6 +13,60 @@
 # New operator library with a custom namespace to allow fusion etc.
 lib = Library("cortex_m", "DEF")
 
+# Import these for the cadence function signatures.
+import executorch.backends.cortex_m.cortex_m_ops_lib  # noqa: F401
+
+###
+# add.Tensor
+###
+
+lib.define(
+    "add.Tensor(Tensor self, Tensor other, ScalarType dtype) -> (Tensor Z)"
+)
+
+lib.define(
+    "add_Tensor.out(Tensor self, Tensor other, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)"
+)
+
+@impl(lib, "add.Tensor", "CompositeExplicitAutograd")
+def aten_add_tensor_impl(
+    input1: torch.Tensor,
+    input2: torch.Tensor,
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    """
+    The implementation of aten add.Tensor.
+    """
+    return exir_ops.edge.cortex_m.add.Tensor(input1, input2, dtype)
+
+###
+# add.out
+###
+
+lib.define(
+    "add(Tensor input1, Tensor input2, ScalarType dtype) -> (Tensor Z)"
+)
+
+lib.define(
+    "add.out(Tensor input1, Tensor input2, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)"
+)
+
+@impl(lib, "add.out", "CompositeExplicitAutograd")
+def add_out_impl(
+    input1: torch.Tensor,
+    input2: torch.Tensor,
+    dtype: torch.dtype,
+    out: torch.Tensor,
+) -> torch.Tensor:
+    """
+    The implementation of cmsis-nn add.out.
+    """
+
+    return exir_ops.edge.cortex_m.add.default(
+        input1, input2, dtype, dtype
+    )
+
 ###
 # dequantize_per_tensor
 ###
@@ -25,7 +79,6 @@
     "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
-
 @register_fake("cortex_m::quantize_per_tensor")
 def quantize_per_tensor_meta(
     input: torch.Tensor,
@@ -37,7 +90,6 @@ def quantize_per_tensor_meta(
 ) -> torch.Tensor:
     return torch.empty_like(input, dtype=dtype)
 
-
 @impl(lib, "quantize_per_tensor", "CompositeExplicitAutograd")
 def quantize_per_tensor_impl(
     input: torch.Tensor,
@@ -96,3 +148,19 @@ def dequantize_per_tensor_impl(
     return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
         input, scale, zero_point, quant_min, quant_max, dtype
     )
+
+lib.define(
+    "softmax(Tensor self, int dim, bool half_to_float) -> Tensor"
+)
+lib.define(
+    "softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
+)
+@impl(lib, "softmax", "CompositeExplicitAutograd")
+def softmax_impl(self: torch.Tensor, dim: int, half_to_float: bool) -> torch.Tensor:
+    # Call your custom edge op or fallback
+    # return exir_ops.edge.cortex_m.softmax(self, dim, half_to_float)
+    # ctx = get_kernel_ctx()  # gets KernelRuntimeContext*
+    return {}
+@impl(lib, "softmax.out", "CompositeExplicitAutograd")
+def softmax_out_impl(self: torch.Tensor, dim: int, half_to_float: bool, out: torch.Tensor) -> torch.Tensor:
+    return exir_ops.edge.cortex_m.softmax_out(self, dim, half_to_float, out)

backends/cortex_m/ops/operators.yaml

Lines changed: 19 additions & 1 deletion
@@ -16,8 +16,26 @@
     - arg_meta: null
       kernel_name: cortex_m::dequantize_per_tensor_out
 
-- func: cortex_m::add.out(Tensor a, Tensor b, Scalar alpha, *, Tensor(a!) out) -> Tensor(a!)
+- func: cortex_m::add.out(Tensor a, Tensor b, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
      kernel_name: cortex_m::add_out
+
+- func: cortex_m::add.Tensor(Tensor self, Tensor other, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::aten_add_tensor
+
+- func: cortex_m::softmax(Tensor self, int dim, bool half_to_float) -> Tensor
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::softmax
+
+- func: cortex_m::softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::softmax_out

backends/cortex_m/passes/replace_quant_nodes_pass.py

Lines changed: 24 additions & 6 deletions
@@ -31,6 +31,22 @@ def _is_qualified_int8_node(args) -> bool:
     def __init__(self):
         super().__init__()
         self.op_replacements = {
+            exir_ops.edge.add: {
+                "new_target": exir_ops.edge.cortex_m.add,
+                "qualifier": lambda args: True,
+            },
+            exir_ops.edge.aten.add.Tensor: {
+                "new_target": exir_ops.edge.cortex_m.add.Tensor,
+                "qualifier": lambda args: True,
+            },
+            exir_ops.edge.aten._softmax.out: {
+                "new_target": exir_ops.edge.cortex_m.softmax.out,
+                "qualifier": lambda args: True,
+            },
+            exir_ops.edge.aten._softmax.default: {
+                "new_target": exir_ops.edge.cortex_m.softmax,  # or .softmax if you have an out variant
+                "qualifier": lambda args: True,
+            },
             exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: {
                 "new_target": exir_ops.edge.cortex_m.quantize_per_tensor.default,
                 "qualifier": self._is_qualified_int8_node,
@@ -51,12 +67,14 @@ def call_operator(
         assert isinstance(
             op, EdgeOpOverload
         ), "Op must be an EdgeOpOverload. Run this pass after to_edge()."
+        print(f"[ReplaceQuantNodesPass] Operator called: {op}, Args: {args}")
 
-        if op in self.op_replacements and self.op_replacements[op]["qualifier"](args):
+        if op in self.op_replacements and self.op_replacements[op]["qualifier"](args):
+            print(f"[ReplaceQuantNodesPass] Replacing {op} with {self.op_replacements[op]['new_target']}")
             return super().call_operator(
-                self.op_replacements[op]["new_target"],
-                args,
-                kwargs,
-                meta,
-            )
+                self.op_replacements[op]["new_target"],
+                args,
+                kwargs,
+                meta,
+            )
         return super().call_operator(op, args, kwargs, meta)
