Commit 649ad2f

Author: sidart

Summary: Initial CMSIS-NN custom kernels port (Take #2)
Test Plan: Reviewers: Subscribers: Tasks: Tags:

1 parent b440e82 commit 649ad2f

File tree

10 files changed: +390 -19 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -530,6 +530,7 @@ endif()
 
 if(EXECUTORCH_BUILD_CORTEX_M)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m/cmsis-nn/ops)
 endif()
 
 if(EXECUTORCH_BUILD_DEVTOOLS)
Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import (
    ops as exir_ops,
)  # To provide the implementation of the operators
from torch.library import impl, Library, register_fake

# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")

###
# add.Tensor
###

lib.define(
    "aten_add_tensor(Tensor self, Tensor other, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)"
)

@impl(lib, "aten_add_tensor", "CompositeExplicitAutograd")
def aten_add_tensor_impl(input1, input2, dtype, out):
    # Delegate to the edge-dialect op; `out` is threaded through as the
    # out argument, matching the schema above.
    return exir_ops.edge.cortex_m.aten_add_tensor.default(input1, input2, dtype, out)


###
# add.out
###

lib.define(
    "add.out(Tensor input1, Tensor input2, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)"
)

@impl(lib, "add.out", "CompositeExplicitAutograd")
def add_out_impl(
    input1: torch.Tensor,
    input2: torch.Tensor,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    """
    The implementation of cmsis-nn add.out.
    """
    return exir_ops.edge.cortex_m.add.default(input1, input2, dtype, out)
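
A quick way to sanity-check these registrations (an illustrative snippet, not part of the commit): importing the module runs Library("cortex_m", "DEF"), so the ops become visible under torch.ops.

import torch
# Import the operator module above first so the "cortex_m" library is defined.
print(torch.ops.cortex_m.add.out)          # the add.out overload defined above
print(torch.ops.cortex_m.aten_add_tensor)  # the aten_add_tensor overload packet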

backends/cortex_m/cmsis-nn/cmsis.yaml

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- op: aten::add.out
  kernels:
    - arg_meta: null
      kernel_name: cortex_m::aten_add_tensor

- op: aten::_softmax.out
  kernels:
    - arg_meta: null
      kernel_name: cortex_m::aten_softmax
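
These bindings route the standard ATen out-variants to the custom kernels named on the right. For reference, the contract those kernels must satisfy looks like this in eager PyTorch (an illustrative sketch, not part of the commit):

import torch

x = torch.randn(2, 4)
y = torch.randn(2, 4)
out = torch.empty(2, 4)
probs = torch.empty(2, 4)

torch.ops.aten.add.out(x, y, out=out)                 # aten::add.out
torch.ops.aten._softmax.out(x, -1, False, out=probs)  # aten::_softmax.out
print(probs.sum(dim=-1))  # softmax rows sum to ~1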

backends/cortex_m/cmsis-nn/ops/CMakeLists.txt

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if(NOT CMAKE_CXX_STANDARD)
  set(CMAKE_CXX_STANDARD 17)
endif()
set(CMAKE_VERBOSE_MAKEFILE ON)

# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../)
endif()

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)

set(EXECUTORCH_ENABLE_LOGGING ON CACHE BOOL "Enable ExecuTorch logging")
set(EXECUTORCH_LOG_LEVEL "DEBUG" CACHE STRING "ExecuTorch log level")

# Path to the CMSIS-NN root. A cache variable so it can be overridden on the
# command line (-DCMSIS_NN_ROOT=...) instead of editing this developer path.
set(CMSIS_NN_ROOT "/home/sidart/working/CMSIS-NN" CACHE PATH "Path to CMSIS-NN root")

# Cortex-M CMSIS ops sources
set(_cortex_m_kernels_cmsis__srcs
  "${EXECUTORCH_ROOT}/backends/cortex_m/cmsis-nn/ops/op_aten_add_tensor.cpp"
  "${EXECUTORCH_ROOT}/backends/cortex_m/cmsis-nn/ops/op_aten_softmax.cpp"
)

# Common include directories
set(_common_include_directories
  ${EXECUTORCH_ROOT}/..
  ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
  ${CMSIS_NN_ROOT}/Include
  ${CMSIS_NN_ROOT} # For any CMake or config includes
)

# Import the prebuilt CMSIS-NN static library as a target
add_library(cmsis_nn STATIC IMPORTED)
set_target_properties(cmsis_nn PROPERTIES
  IMPORTED_LOCATION "${CMSIS_NN_ROOT}/build/libcmsis-nn.a"
  INTERFACE_INCLUDE_DIRECTORIES "${CMSIS_NN_ROOT}/Include"
)

# Build the cortex_m_cmsis_kernels static library
add_library(cortex_m_cmsis_kernels ${_cortex_m_kernels_cmsis__srcs})

# Include directories for cortex_m_cmsis_kernels
target_include_directories(cortex_m_cmsis_kernels
  PRIVATE ${_common_include_directories}
)

# Link libraries: executorch and the CMSIS-NN imported target
target_link_libraries(cortex_m_cmsis_kernels
  PRIVATE cmsis_nn executorch
)

# Generate C++ bindings for kernels and operators
gen_selected_ops(
  LIB_NAME "cortex_m_cmsis_nn_ops_lib" OPS_SCHEMA_YAML
  "${CMAKE_CURRENT_LIST_DIR}/../cmsis.yaml" "" ""
)
generate_bindings_for_kernels(
  LIB_NAME "cortex_m_cmsis_nn_ops_lib" FUNCTIONS_YAML
  ${CMAKE_CURRENT_SOURCE_DIR}/../cmsis.yaml
)
gen_operators_lib(
  LIB_NAME "cortex_m_cmsis_nn_ops_lib" KERNEL_LIBS cortex_m_cmsis_kernels DEPS executorch
)

# Garbage-collect unused sections to keep the embedded binary small.
set(CMAKE_EXE_LINKER_FLAGS "-Wl,--gc-sections")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffunction-sections -fdata-sections")

# Install targets and headers
install(
  TARGETS cortex_m_cmsis_kernels cortex_m_cmsis_nn_ops_lib
  DESTINATION lib
  PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/cmsis-nn/ops/
)

backends/cortex_m/cmsis-nn/ops/op_aten_add_tensor.cpp

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/core/portable_type/tensor.h> // for torch::executor::Tensor
#include <executorch/runtime/core/portable_type/scalar.h> // for torch::executor::Scalar
#include <cassert>
#include <cinttypes> // for PRId8

namespace cortex_m {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using Scalar = executorch::aten::Scalar;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

Tensor& aten_add_tensor(
    KernelRuntimeContext& ctx,
    const Tensor& input1,
    const Tensor& input2,
    const Scalar& alpha,
    Tensor& out) {
  // CMSIS-NN optimized implementation goes here; 'out' is returned as per
  // the ExecuTorch out-variant kernel signature.
  ET_LOG(Info, "aten_add_tensor kernel called");

  assert(false); // placeholder: kernel body not implemented yet
  return out;
}

Tensor& add_out(
    KernelRuntimeContext& ctx,
    const Tensor& input1,
    const Tensor& input2,
    const Scalar& alpha,
    Tensor& out) {
  ET_LOG(Info, "add_out kernel called");

  // Ensure inputs are int8 (Char)
  ET_CHECK_MSG(
      input1.scalar_type() == ScalarType::Char,
      "input1.scalar_type() %" PRId8 " is not char type",
      static_cast<int8_t>(input1.scalar_type()));

  ET_CHECK_MSG(
      input2.scalar_type() == ScalarType::Char,
      "input2.scalar_type() %" PRId8 " is not char type",
      static_cast<int8_t>(input2.scalar_type()));

  // Check output dtype is float
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float,
      "out.scalar_type() %" PRId8 " is not float",
      static_cast<int8_t>(out.scalar_type()));

  assert(false); // placeholder: dtype checks only, kernel body not implemented
  return out;
}

} // namespace native
} // namespace cortex_m
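
Both kernels above are stubs that trap after their checks. For orientation, the arithmetic an int8 elementwise add eventually has to perform is dequantize, add, requantize, saturate; a reference sketch (scales, zero points, and values here are illustrative assumptions, not from this commit):

def quantized_add_reference(a_q, b_q, a_scale, b_scale, out_scale,
                            a_zp=0, b_zp=0, out_zp=0):
    """Reference int8 add: dequantize both inputs, add in float,
    requantize to the output scale, and saturate to int8."""
    out = []
    for a, b in zip(a_q, b_q):
        real = (a - a_zp) * a_scale + (b - b_zp) * b_scale
        q = round(real / out_scale) + out_zp
        out.append(max(-128, min(127, q)))
    return out

print(quantized_add_reference([10, -5], [20, 30], 0.1, 0.1, 0.1))  # [30, 25]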

backends/cortex_m/cmsis-nn/ops/op_aten_softmax.cpp

Lines changed: 116 additions & 0 deletions

@@ -0,0 +1,116 @@
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/core/portable_type/tensor.h> // for torch::executor::Tensor
#include <executorch/runtime/core/portable_type/scalar.h> // for torch::executor::Scalar

#include <vector>
#include <algorithm>
#include <cmath>
#include <cstdint>

extern "C" {
#include "Include/arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using Tensor = torch::executor::Tensor;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

// Determine a quantization scale from fp32 data; 255 is the number of steps
// between the 256 int8 levels [-128, 127]. Caveat: a constant input yields
// a zero scale.
float determine_input_scale(const float* data, int size) {
  float min_val = *std::min_element(data, data + size);
  float max_val = *std::max_element(data, data + size);
  return (max_val - min_val) / 255.0f;
}

// Quantize fp32 to int8
void quantize_tensor(const float* input, int8_t* output, int size,
                     float scale, int32_t zero_point) {
  for (int i = 0; i < size; i++) {
    int32_t quantized =
        static_cast<int32_t>(std::round(input[i] / scale)) + zero_point;
    // Clamp so the quantized value stays within the int8_t limits [-128, 127].
    output[i] = std::clamp(
        quantized, static_cast<int32_t>(-128), static_cast<int32_t>(127));
  }
}

// Dequantize int8 to fp32
void dequantize_tensor(const int8_t* input, float* output, int size,
                       float scale, int32_t zero_point) {
  for (int i = 0; i < size; i++) {
    output[i] = (input[i] - zero_point) * scale;
  }
}

// Converts a floating-point scale to a CMSIS-NN fixed-point multiplier and shift.
// scale: the floating-point scale factor from ExecuTorch quantization
// multiplier: output fixed-point multiplier (Q31 format)
// shift: output left-shift amount (positive means left shift)
// diff_min: output minimum difference threshold (usually -128 for int8)
void convert_scale_to_cmsis_params(float scale, int32_t* multiplier,
                                   int32_t* shift, int32_t* diff_min) {
  if (scale == 0.0f) {
    *multiplier = 0;
    *shift = 0;
    *diff_min = -128;
    return;
  }
  // Decompose scale into mantissa and exponent: scale = mantissa * 2^exponent
  int exponent;
  float mantissa = std::frexp(scale, &exponent); // mantissa in [0.5, 1)
  // Convert mantissa to Q31 fixed-point format
  int64_t q_fixed = static_cast<int64_t>(std::round(mantissa * (1ll << 31)));
  if (q_fixed == (1ll << 31)) { // mantissa rounded up to 1.0; renormalize
    q_fixed /= 2;
    ++exponent;
  }
  *multiplier = static_cast<int32_t>(q_fixed);
  // CMSIS-NN expects a left shift, so negate the exponent to get the shift
  *shift = -exponent;
  // Typical diff_min for int8 softmax
  *diff_min = -128;
}

torch::executor::Tensor& aten_softmax(
    KernelRuntimeContext& context,
    const Tensor& self,
    int64_t dim,
    bool half_to_float,
    Tensor& out) {
  ET_LOG(Info, "CMSIS-NN quantized softmax kernel called");

  // Step 1: Extract fp32 data
  const float* input_data_fp32 = self.data_ptr<float>();
  float* output_data_fp32 = out.data_ptr<float>();

  // Step 2: Get tensor dimensions (assumes a 2-D [rows, cols] input)
  int rows = self.sizes()[0];
  int cols = self.sizes()[1];

  // Step 3: Quantize input (fp32 -> int8).
  // Determine an appropriate scale/zero_point.
  float input_scale = determine_input_scale(input_data_fp32, rows * cols);

  // '0' is a reasonable default zero point for symmetric int8 quantization,
  // especially if the input data is centered around zero; else TBD.
  int32_t input_zero_point = 0;

  std::vector<int8_t> input_quantized(rows * cols);
  quantize_tensor(input_data_fp32, input_quantized.data(),
                  rows * cols, input_scale, input_zero_point);

  // Step 4: Convert to CMSIS-NN parameters
  int32_t input_mult, input_shift, diff_min;
  convert_scale_to_cmsis_params(input_scale, &input_mult, &input_shift, &diff_min);

  // Step 5: Call the CMSIS-NN kernel
  std::vector<int8_t> output_quantized(rows * cols);
  arm_softmax_s8(input_quantized.data(), rows, cols,
                 input_mult, input_shift, diff_min,
                 output_quantized.data());

  // Step 6: Dequantize output (int8 -> fp32)
  dequantize_tensor(output_quantized.data(), output_data_fp32,
                    rows * cols, input_scale, input_zero_point);

  return out;
}

} // namespace native
} // namespace cortex_m
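
The fixed-point conversion above is the standard mantissa/exponent trick: scale = mantissa * 2^exponent with mantissa in [0.5, 1), and the mantissa stored as a Q31 integer. A small illustrative check of that decomposition and of the quantize/dequantize round trip (values are arbitrary, not from this commit):

import math

def scale_to_multiplier_shift(scale):
    """Mirror of convert_scale_to_cmsis_params: multiplier in Q31, shift = -exponent."""
    mantissa, exponent = math.frexp(scale)  # scale = mantissa * 2**exponent
    return round(mantissa * (1 << 31)), -exponent

mult, shift = scale_to_multiplier_shift(0.0078125)  # 1/128 = 0.5 * 2**-6
print(mult, shift)                    # 1073741824 6
print((mult / 2**31) * 2.0**(-shift)) # reconstructs 0.0078125

def quantize(x, scale, zp=0):
    return max(-128, min(127, round(x / scale) + zp))

def dequantize(q, scale, zp=0):
    return (q - zp) * scale

q = quantize(0.42, 0.0078125)
print(q, dequantize(q, 0.0078125))  # 54 0.421875 (round-trip error under one step)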

examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake

Lines changed: 1 addition & 1 deletion

@@ -77,7 +77,7 @@ elseif(
 elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "cortex-m4(\\+|$)"
        OR CMAKE_SYSTEM_PROCESSOR MATCHES "cortex-m7(\\+|$)"
 )
-  set(FLOAT hard)
+  set(FLOAT soft)
   set(FPU_CONFIG "fpv4-sp-d16")
   add_compile_options(-mfpu=${FPU_CONFIG})
   add_link_options(-mfpu=${FPU_CONFIG})
