Skip to content

Commit 11a8f50

Browse files
ddavis-2015veblush
andauthored
Additional const-tensor checks (CMSIS_NN) (#3229)
@tensorflow/micro Add additional const-tensor checks to the following CMSIS_NN kernels that attempt to use tensor data during the Prepare phase: FULLY_CONNECTED SVDF UNIDIRECTIONAL_SEQUENCE_LSTM Unit tests updated. bug=fixes #3228 Co-authored-by: Esun Kim <[email protected]>
1 parent c3c78cb commit 11a8f50

File tree

6 files changed

+44
-3
lines changed

6 files changed

+44
-3
lines changed

tensorflow/lite/micro/kernels/cmsis_nn/fully_connected.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
151151
data->kernel_sums = nullptr;
152152

153153
#if defined(KERNELS_OPTIMIZED_FOR_SPEED)
154-
const int8_t* filter_data = GetTensorData<const int8_t>(filter);
155-
156-
if (buf_size > 0 && filter_data != nullptr) {
154+
if (buf_size > 0 && IsConstantTensor(filter) &&
155+
(bias == nullptr || IsConstantTensor(bias))) {
156+
const int8_t* filter_data = GetTensorData<const int8_t>(filter);
157157
const int32_t input_offset = -data->reference_op_data.input_zero_point;
158158
const int32_t filter_offset =
159159
-data->reference_op_data.filter_zero_point;

tensorflow/lite/micro/kernels/cmsis_nn/svdf.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ TfLiteStatus CmsisNnPrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
193193

194194
if (buf_size > 0) {
195195
#if defined(KERNELS_OPTIMIZED_FOR_SPEED)
196+
TF_LITE_ENSURE(context, IsConstantTensor(weights_feature));
196197
data->kernel_sums = static_cast<int32_t*>(
197198
context->AllocatePersistentBuffer(context, buf_size));
198199

@@ -210,6 +211,12 @@ TfLiteStatus CmsisNnPrepareSvdf(TfLiteContext* context, TfLiteNode* node) {
210211
"Either KERNELS_OPTIMIZED_FOR_SIZE or KERNELS_OPTIMIZED_FOR_SPEED "
211212
"must be defined");
212213
return kTfLiteError;
214+
#endif
215+
} else {
216+
// safety first!
217+
data->kernel_sums = nullptr;
218+
#if defined(KERNELS_OPTIMIZED_FOR_SIZE)
219+
data->scratch_weight_tensor_index = -1;
213220
#endif
214221
}
215222

@@ -310,6 +317,7 @@ TfLiteStatus EvalIntegerSVDF(TfLiteContext* context, TfLiteNode* node,
310317
#if defined(KERNELS_OPTIMIZED_FOR_SPEED)
311318
ctx.buf = data.kernel_sums;
312319
#elif defined(KERNELS_OPTIMIZED_FOR_SIZE)
320+
TF_LITE_ENSURE(context, data.scratch_weight_tensor_index != -1);
313321
ctx.buf = static_cast<int32_t*>(
314322
context->GetScratchBuffer(context, data.scratch_weight_tensor_index));
315323

tensorflow/lite/micro/kernels/cmsis_nn/unidirectional_sequence_lstm.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,18 @@ TfLiteStatus UnidirectionalSequenceLstmPrepare(TfLiteContext* context,
344344
// All TempTfLiteTensors will be deallocated through the destructor.
345345
LstmTensors lstm_tensors(context, node);
346346
TF_LITE_ENSURE_OK(context, lstm_tensors.ValidateTensorStatus(context));
347+
// Additional validation of weights and biases.
348+
// ValidateTensorStatus() ensures no tensor is <nullptr>.
349+
for (size_t i = 1; i < 9; i++) {
350+
// check weight
351+
TF_LITE_ENSURE(context,
352+
IsConstantTensor(lstm_tensors.GetInternalTensor(i)));
353+
}
354+
for (size_t i = 12; i < 16; i++) {
355+
// check bias
356+
TF_LITE_ENSURE(context,
357+
IsConstantTensor(lstm_tensors.GetInternalTensor(i)));
358+
}
347359

348360
op_data_lstm->cell_gate_nonlinear_type = builtin_data->activation;
349361
op_data_lstm->size_info =

tensorflow/lite/micro/kernels/fully_connected_test.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,11 @@ TfLiteStatus ValidateFullyConnectedGoldens(
307307
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
308308
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
309309

310+
tensors[1].allocation_type = kTfLiteMmapRo;
311+
if (!null_bias) {
312+
tensors[2].allocation_type = kTfLiteMmapRo;
313+
}
314+
310315
#ifdef USE_TFLM_COMPRESSION
311316

312317
TestCompressedList<kMaxTensors> tcl;

tensorflow/lite/micro/kernels/svdf_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,8 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
498498
int outputs_array_data[] = {1, 5};
499499
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
500500

501+
tensors[1].allocation_type = kTfLiteMmapRo;
502+
501503
const TFLMRegistration registration = Register_SVDF();
502504
micro::KernelRunner runner(registration, tensors, tensor_count, inputs_array,
503505
outputs_array, &params);

tensorflow/lite/micro/kernels/unidirectional_sequence_lstm_test.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@ namespace {
2929

3030
constexpr int kLstmMaxNumInputOutputTensors = 24 + 1;
3131

32+
// Set weights and biases to be const-tensors
33+
void SetConstTensors(TfLiteTensor* tensors) {
34+
for (size_t i = 1; i < 9; i++) {
35+
// weights
36+
tensors[i].allocation_type = kTfLiteMmapRo;
37+
}
38+
for (size_t i = 12; i < 16; i++) {
39+
// biases
40+
tensors[i].allocation_type = kTfLiteMmapRo;
41+
}
42+
}
43+
3244
// Validate the output result array with golden values
3345
template <typename T>
3446
void ValidateResultGoldens(const T* golden, const T* output_data,
@@ -49,6 +61,7 @@ void TestUnidirectionalLSTMInteger(
4961
LstmNodeContent<ActivationType, WeightType, BiasType, CellType, batch_size,
5062
time_steps, input_dimension, state_dimension>&
5163
node_contents) {
64+
SetConstTensors(node_contents.GetTensors());
5265
const TFLMRegistration registration = Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
5366
auto buildin_data = node_contents.BuiltinData();
5467
micro::KernelRunner runner(
@@ -101,6 +114,7 @@ void TestUnidirectionalLSTMFloat(
101114
const float hidden_state_tolerance, const float cell_state_tolerance,
102115
LstmNodeContent<float, float, float, float, batch_size, time_steps,
103116
input_dimension, state_dimension>& node_contents) {
117+
SetConstTensors(node_contents.GetTensors());
104118
const TFLMRegistration registration = Register_UNIDIRECTIONAL_SEQUENCE_LSTM();
105119
auto buildin_data = node_contents.BuiltinData();
106120
micro::KernelRunner runner(

0 commit comments

Comments
 (0)