Commit f953921
author: sidart (committed)

Summary: Initial CMSIS-NN custom kernels port (Take #2)
Test Plan: Reviewers: Subscribers: Tasks: Tags:

1 parent b440e82 commit f953921
File tree: 8 files changed (+248 −4 lines changed)

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -530,6 +530,7 @@ endif()
 
 if(EXECUTORCH_BUILD_CORTEX_M)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m/cmsis-nn/ops)
 endif()
 
 if(EXECUTORCH_BUILD_DEVTOOLS)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.dialects._ops import (
9+
ops as exir_ops,
10+
) # To provide the implementation of the operators
11+
from torch.library import impl, Library, register_fake
12+
13+
# New operator library with a custom namespace to allow fusion etc.
14+
lib = Library("cortex_m", "DEF")
15+
16+
###
17+
# add.Tensor
18+
###
19+
20+
lib.define("aten_add_tensor(Tensor self, Tensor other, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)")
21+
22+
@impl(lib, "aten_add_tensor", "CompositeExplicitAutograd")
23+
def aten_add_tensor_impl(input1, input2, dtype, out):
24+
return exir_ops.edge.cortex_m.aten_add_tensor.default(input1, input2, dtype, dtype)
25+
26+
27+
###
28+
# add.out
29+
###
30+
31+
lib.define(
32+
"add.out(Tensor input1, Tensor input2, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)"
33+
)
34+
35+
@impl(lib, "add.out", "CompositeExplicitAutograd")
36+
def add_out_impl(
37+
input1: torch.Tensor,
38+
input2: torch.Tensor,
39+
dtype: torch.dtype,
40+
out: torch.Tensor,
41+
) -> torch.Tensor:
42+
"""
43+
The implementation of cmsis-nn add.out.
44+
"""
45+
46+
return exir_ops.edge.cortex_m.add.default(
47+
input1, input2, dtype, dtype
48+
)

backends/cortex_m/cmsis-nn/cmsis.yaml

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+- op: aten::add.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::aten_add_tensor
+
+- op: aten::_softmax.out
+  kernels:
+    - arg_meta: null
+      kernel_name: cortex_m::aten_softmax
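
For orientation, each kernel_name above is expected to resolve to a C++ function with ExecuTorch's out-variant signature. A sketch of the declarations these entries bind to, assuming the codegen resolves the functions in the library's native sub-namespace (signatures copied from the kernel sources added later in this commit):

// Sketch only: declarations the cmsis.yaml entries are assumed to bind to.
// The cortex_m::native namespace follows the convention used by the kernel
// sources in this commit; the exact resolution is performed by the codegen.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cortex_m {
namespace native {

// aten::add.out -> kernel_name cortex_m::aten_add_tensor
torch::executor::Tensor& aten_add_tensor(
    torch::executor::KernelRuntimeContext& ctx,
    const torch::executor::Tensor& input1,
    const torch::executor::Tensor& input2,
    const torch::executor::Scalar& alpha,
    torch::executor::Tensor& out);

// aten::_softmax.out -> kernel_name cortex_m::aten_softmax
torch::executor::Tensor& aten_softmax(
    torch::executor::KernelRuntimeContext& context,
    const torch::executor::Tensor& self,
    int64_t dim,
    bool half_to_float,
    torch::executor::Tensor& out);

} // namespace native
} // namespace cortex_m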
backends/cortex_m/cmsis-nn/ops/CMakeLists.txt

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+cmake_minimum_required(VERSION 3.19)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+if(NOT CMAKE_CXX_STANDARD)
+  set(CMAKE_CXX_STANDARD 17)
+endif()
+
+# Source root directory for executorch.
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../)
+endif()
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
+
+set(EXECUTORCH_ENABLE_LOGGING ON CACHE BOOL "Enable ExecuTorch logging")
+set(EXECUTORCH_LOG_LEVEL "DEBUG" CACHE STRING "ExecuTorch log level")
+
+# Cortex-M CMSIS ops that are needed to run this model.
+set(_cortex_m_kernels_cmsis__srcs
+    "${EXECUTORCH_ROOT}/backends/cortex_m/cmsis-nn/ops/op_aten_add_tensor.cpp"
+    "${EXECUTORCH_ROOT}/backends/cortex_m/cmsis-nn/ops/op_aten_softmax.cpp"
+)
+
+# Let files say "include <executorch/path/to/header.h>".
+set(_common_include_directories ${EXECUTORCH_ROOT}/..
+    ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
+)
+
+add_library(cortex_m_cmsis_kernels ${_cortex_m_kernels_cmsis__srcs})
+target_link_libraries(cortex_m_cmsis_kernels PRIVATE executorch)
+target_compile_options(cortex_m_cmsis_kernels PUBLIC ${_common_compile_options})
+
+# Generate C++ bindings to register kernels into both PyTorch (for AOT) and
+# ExecuTorch (for runtime). Here we select all ops listed in cmsis.yaml.
+gen_selected_ops(
+  LIB_NAME "cortex_m_cmsis_nn_ops_lib" OPS_SCHEMA_YAML
+  "${CMAKE_CURRENT_LIST_DIR}/../cmsis.yaml" "" ""
+)
+generate_bindings_for_kernels(
+  LIB_NAME "cortex_m_cmsis_nn_ops_lib" FUNCTIONS_YAML
+  ${CMAKE_CURRENT_SOURCE_DIR}/../cmsis.yaml
+)
+message("Generated files ${gen_command_sources}")
+
+gen_operators_lib(
+  LIB_NAME "cortex_m_cmsis_nn_ops_lib" KERNEL_LIBS cortex_m_cmsis_kernels
+  DEPS executorch
+)
+
+install(
+  TARGETS cortex_m_cmsis_kernels cortex_m_cmsis_nn_ops_lib
+  DESTINATION lib
+  PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/cmsis-nn/ops/
+)
backends/cortex_m/cmsis-nn/ops/op_aten_add_tensor.cpp

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/core/portable_type/tensor.h> // for torch::executor::Tensor
+#include <executorch/runtime/core/portable_type/scalar.h> // for torch::executor::Scalar
+#include <iostream>
+
+namespace cortex_m {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+using Scalar = executorch::aten::Scalar;
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+torch::executor::Tensor& aten_add_tensor(
+    torch::executor::KernelRuntimeContext& ctx,
+    const torch::executor::Tensor& input1,
+    const torch::executor::Tensor& input2,
+    const torch::executor::Scalar& alpha,
+    torch::executor::Tensor& out) {
+  // Your CMSIS-NN optimized implementation here.
+  // Return 'out' tensor as per the ExecuTorch kernel signature.
+  std::cout << "add_out kernel called" << std::endl;
+  ET_LOG(Info, "xxxxxxxxxx add_out kernel called");
+
+  assert(false);
+  assert(true);
+  return out;
+}
+
+torch::executor::Tensor& add_out(
+    torch::executor::KernelRuntimeContext& ctx,
+    const torch::executor::Tensor& input1,
+    const torch::executor::Tensor& input2,
+    const torch::executor::Scalar& alpha,
+    torch::executor::Tensor& out) {
+  std::cout << "add_out kernel called" << std::endl;
+  ET_LOG(Info, "xxxxxxxxxx add_out kernel called");
+
+  // Ensure inputs are char (int8) type.
+  ET_CHECK_MSG(
+      input1.scalar_type() == ScalarType::Char,
+      "input1.scalar_type() %" PRId8 " is not char type",
+      static_cast<int8_t>(input1.scalar_type()));
+
+  ET_CHECK_MSG(
+      input2.scalar_type() == ScalarType::Char,
+      "input2.scalar_type() %" PRId8 " is not char type",
+      static_cast<int8_t>(input2.scalar_type()));
+
+  // Check output dtype is float.
+  ET_CHECK_MSG(
+      out.scalar_type() == ScalarType::Float,
+      "out.scalar_type() %" PRId8 " is not float",
+      static_cast<int8_t>(out.scalar_type()));
+
+  // Check dtype is int8 (Char).
+  /*ET_CHECK_MSG(
+      dtype == ScalarType::Char,
+      "dtype %" PRId8 " is not int8 (Char)",
+      static_cast<int8_t>(dtype));*/
+
+  assert(false);
+
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m
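
The kernel bodies above are stubs that log and assert. For reference, a minimal sketch of what the CMSIS-NN-backed int8 path could look like, assuming same-shape contiguous int8 tensors, CMSIS-NN's arm_elementwise_add_s8 from arm_nnfunctions.h, and pre-computed requantization parameters (the offsets, multipliers, and shifts below are illustrative placeholders, not values from this commit):

// Sketch only: a possible CMSIS-NN int8 elementwise add, not the committed
// implementation. Assumes same-shape contiguous int8 tensors and that the
// requantization parameters were derived offline from the tensors' scales.
#include <arm_nnfunctions.h> // CMSIS-NN

torch::executor::Tensor& add_out_cmsis_sketch(
    torch::executor::KernelRuntimeContext& ctx,
    const torch::executor::Tensor& input1,
    const torch::executor::Tensor& input2,
    torch::executor::Tensor& out) {
  // Illustrative placeholder quantization parameters; real values come from
  // the tensors' quantization metadata at export time.
  const int32_t input1_offset = 0, input1_mult = 1073741824, input1_shift = 0;
  const int32_t input2_offset = 0, input2_mult = 1073741824, input2_shift = 0;
  const int32_t left_shift = 20;
  const int32_t out_offset = 0, out_mult = 1073741824, out_shift = -18;

  // Requantize both inputs, add, then requantize into the output tensor.
  arm_elementwise_add_s8(
      input1.const_data_ptr<int8_t>(),
      input2.const_data_ptr<int8_t>(),
      input1_offset, input1_mult, input1_shift,
      input2_offset, input2_mult, input2_shift,
      left_shift,
      out.mutable_data_ptr<int8_t>(),
      out_offset, out_mult, out_shift,
      /*out_activation_min=*/-128, /*out_activation_max=*/127,
      static_cast<int32_t>(out.numel()));
  return out;
}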
backends/cortex_m/cmsis-nn/ops/op_aten_softmax.cpp

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/core/portable_type/tensor.h> // for torch::executor::Tensor
+#include <executorch/runtime/core/portable_type/scalar.h> // for torch::executor::Scalar
+#include <iostream>
+
+namespace cortex_m {
+namespace native {
+
+using Tensor = executorch::aten::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+using Scalar = executorch::aten::Scalar;
+using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+
+torch::executor::Tensor& aten_softmax(
+    torch::executor::KernelRuntimeContext& context,
+    const torch::executor::Tensor& self,
+    int64_t dim,
+    bool half_to_float,
+    torch::executor::Tensor& out) {
+  // Your CMSIS-NN optimized implementation here.
+  // Return 'out' tensor as per the ExecuTorch kernel signature.
+  // std::cout << "softmax kernel called" << std::endl;
+  ET_LOG(Info, "xxxxxxxxxx softmax kernel called");
+
+  // assert(false);
+  // assert(true);
+  return out;
+}
+
+} // namespace native
+} // namespace cortex_m
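
As with add, the softmax body is a stub. A minimal sketch of a CMSIS-NN-backed int8 softmax over the last dimension, assuming a contiguous int8 tensor and offline-computed scaling parameters (mult, shift, and diff_min below are illustrative placeholders derived from the input scale at export time), using CMSIS-NN's arm_softmax_s8:

// Sketch only: a possible CMSIS-NN int8 softmax, not the committed
// implementation. Assumes a contiguous int8 tensor, softmax over the last
// dimension, and offline-computed scaling parameters.
#include <arm_nnfunctions.h> // CMSIS-NN

torch::executor::Tensor& aten_softmax_cmsis_sketch(
    torch::executor::KernelRuntimeContext& context,
    const torch::executor::Tensor& self,
    torch::executor::Tensor& out) {
  // Treat the tensor as a batch of rows; softmax runs row by row.
  const int32_t row_size = static_cast<int32_t>(self.size(self.dim() - 1));
  const int32_t num_rows = static_cast<int32_t>(self.numel()) / row_size;

  // Illustrative placeholders; real values are derived from the input's
  // quantization scale when the model is exported.
  const int32_t mult = 1077952576;
  const int32_t shift = 23;
  const int32_t diff_min = -248;

  arm_softmax_s8(
      self.const_data_ptr<int8_t>(),
      num_rows,
      row_size,
      mult,
      shift,
      diff_min,
      out.mutable_data_ptr<int8_t>());
  return out;
}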

examples/arm/executor_runner/CMakeLists.txt

Lines changed: 15 additions & 1 deletion
@@ -539,6 +539,18 @@ set_property(
   PROPERTY IMPORTED_LOCATION
     "${ET_BUILD_DIR_PATH}/backends/cortex_m/libcortex_m_kernels.a"
 )
+add_library(cortex_m_cmsis_nn_ops_lib STATIC IMPORTED)
+set_property(
+  TARGET cortex_m_cmsis_nn_ops_lib
+  PROPERTY IMPORTED_LOCATION
+    "${ET_BUILD_DIR_PATH}/backends/cortex_m/cmsis-nn/ops/libcortex_m_cmsis_nn_ops_lib.a"
+)
+add_library(cortex_m_cmsis_kernels STATIC IMPORTED)
+set_property(
+  TARGET cortex_m_cmsis_kernels
+  PROPERTY IMPORTED_LOCATION
+    "${ET_BUILD_DIR_PATH}/backends/cortex_m/cmsis-nn/ops/libcortex_m_cmsis_kernels.a"
+)
 add_library(extension_runner_util STATIC IMPORTED)
 set_property(
   TARGET extension_runner_util

@@ -580,11 +592,13 @@ list(APPEND arm_executor_runner_link
   "-Wl,--whole-archive"
   executorch_delegate_ethos_u
   cortex_m_ops_lib
+  cortex_m_cmsis_nn_ops_lib
   quantized_ops_lib
   portable_ops_lib
   quantized_kernels
-  cortex_m_kernels
   portable_kernels
+  cortex_m_kernels
+  cortex_m_cmsis_kernels
   "-Wl,--no-whole-archive"
   -Xlinker -Map=arm_executor_runner.map
 )

runtime/kernel/operator_registry.cpp

Lines changed: 11 additions & 3 deletions
@@ -81,17 +81,20 @@ Error register_kernels_internal(const Span<const Kernel> kernels) {
       et_pal_get_shared_library_name(kernels.data());
 
   for (const auto& kernel : kernels) {
+    bool duplicate = false;
     // Linear search. This is fine if the number of kernels is small.
     for (size_t i = 0; i < num_registered_kernels; i++) {
       Kernel k = registered_kernels[i];
       if (strcmp(kernel.name_, k.name_) == 0 &&
           kernel.kernel_key_ == k.kernel_key_) {
-        ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name);
+        ET_LOG(Error, "! Re-registering %s, from %s", k.name_, lib_name);
         ET_LOG_KERNEL_KEY(k.kernel_key_);
-        return Error::RegistrationAlreadyRegistered;
+        // return Error::RegistrationAlreadyRegistered;
+        duplicate = true;
       }
     }
-    registered_kernels[num_registered_kernels++] = kernel;
+    if (!duplicate)
+      registered_kernels[num_registered_kernels++] = kernel;
   }
   ET_LOG(
       Debug,

@@ -238,9 +241,12 @@ Result<OpFunction> get_op_function_from_registry(
     return err;
   }
   KernelKey kernel_key = KernelKey(key_string.data());
+  // ET_LOG(Debug, "get_op_function_from_registry: name %s", name);
+  ET_LOG_TENSOR_META(meta_list);
 
   int32_t fallback_idx = -1;
   for (size_t idx = 0; idx < num_registered_kernels; idx++) {
+    ET_LOG(Info, "get_op_function_from_registry Checking kernel %s", registered_kernels[idx].name_);
     if (strcmp(registered_kernels[idx].name_, name) == 0) {
       if (registered_kernels[idx].kernel_key_ == kernel_key) {
         return registered_kernels[idx].op_;

@@ -250,7 +256,9 @@
       }
     }
   }
+
   if (fallback_idx != -1) {
+    ET_LOG(Info, "get_op_function_from_registry: fallback kernel %s", registered_kernels[fallback_idx].name_);
     return registered_kernels[fallback_idx].op_;
   }
   ET_LOG(Error, "kernel '%s' not found.", name);
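
This change makes duplicate registrations tolerated instead of fatal: when a kernel with the same name and key is already present, the new entry is skipped, so whichever library registers first wins. A standalone sketch of that first-wins behavior (simplified types, not ExecuTorch code):

// Sketch only: illustrates the first-wins duplicate handling introduced above,
// using simplified stand-in types rather than the real Kernel/KernelKey.
#include <cstdio>
#include <cstring>

struct Kernel { const char* name_; int key_; };

static Kernel registered_kernels[16];
static size_t num_registered_kernels = 0;

void register_kernel(const Kernel& kernel) {
  bool duplicate = false;
  for (size_t i = 0; i < num_registered_kernels; i++) {
    const Kernel& k = registered_kernels[i];
    if (std::strcmp(kernel.name_, k.name_) == 0 && kernel.key_ == k.key_) {
      // Previously this path returned RegistrationAlreadyRegistered.
      std::printf("! Re-registering %s (skipped)\n", k.name_);
      duplicate = true;
    }
  }
  if (!duplicate) {
    registered_kernels[num_registered_kernels++] = kernel;
  }
}

int main() {
  register_kernel({"aten::add.out", 1}); // e.g. CMSIS-NN lib, registered first: wins
  register_kernel({"aten::add.out", 1}); // e.g. portable lib: logged and skipped
  std::printf("registered: %zu\n", num_registered_kernels); // prints 1
  return 0;
}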
