Skip to content

Commit 40aedca

Browse files
authored
Merge branch 'main' into bugfix/issue-3075-3076-3077-non-existent-buffer
2 parents 5ce923e + ecf58d7 commit 40aedca

File tree

11 files changed

+1096
-145
lines changed

11 files changed

+1096
-145
lines changed

tensorflow/lite/micro/kernels/BUILD

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,20 @@ tflm_cc_library(
7979
],
8080
)
8181

82+
tflm_cc_library(
83+
name = "decode_test_helpers",
84+
hdrs = [
85+
"decode_test_helpers.h",
86+
],
87+
deps = [
88+
":kernel_runner",
89+
":micro_ops",
90+
"//tensorflow/lite/c:common",
91+
"//tensorflow/lite/micro:test_helpers",
92+
"//tensorflow/lite/micro/testing:micro_test",
93+
],
94+
)
95+
8296
tflm_cc_library(
8397
name = "decompress",
8498
srcs = [
@@ -239,6 +253,7 @@ tflm_kernel_cc_library(
239253
"decode.cc",
240254
"decode_state.cc",
241255
"decode_state_lut.cc",
256+
"decode_state_prune.cc",
242257
"depth_to_space.cc",
243258
"depthwise_conv.cc",
244259
"depthwise_conv_common.cc",
@@ -332,6 +347,7 @@ tflm_kernel_cc_library(
332347
"conv.h",
333348
"decode_state.h",
334349
"decode_state_lut.h",
350+
"decode_state_prune.h",
335351
"depthwise_conv.h",
336352
"dequantize.h",
337353
"ethosu.h",
@@ -648,12 +664,29 @@ tflm_cc_test(
648664
],
649665
)
650666

667+
tflm_cc_test(
668+
name = "decode_state_prune_test",
669+
srcs = [
670+
"decode_state_prune_test.cc",
671+
],
672+
deps = [
673+
":decode_test_helpers",
674+
":kernel_runner",
675+
"//tensorflow/lite/c:common",
676+
"//tensorflow/lite/micro:debug_log",
677+
"//tensorflow/lite/micro:op_resolvers",
678+
"//tensorflow/lite/micro:test_helpers",
679+
"//tensorflow/lite/micro/testing:micro_test",
680+
],
681+
)
682+
651683
tflm_cc_test(
652684
name = "decode_test",
653685
srcs = [
654686
"decode_test.cc",
655687
],
656688
deps = [
689+
":decode_test_helpers",
657690
":kernel_runner",
658691
"//tensorflow/lite/c:common",
659692
"//tensorflow/lite/micro:debug_log",

tensorflow/lite/micro/kernels/Makefile.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \
123123
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \
124124
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \
125125
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \
126+
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_prune_test.cc \
126127
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \
127128
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depth_to_space_test.cc \
128129
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/depthwise_conv_test.cc \

tensorflow/lite/micro/kernels/decode.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
6363
break;
6464
}
6565

66+
TF_LITE_ENSURE(context, IsConstantTensor(input));
67+
TF_LITE_ENSURE(context, IsConstantTensor(ancillary));
68+
6669
if (DecodeState::Version(*ancillary) != 1) {
6770
MicroPrintf("version %u != 1", DecodeState::Version(*ancillary));
6871
status = kTfLiteError;
@@ -75,6 +78,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
7578
dsp = DecodeState::CreateDecodeStateLUT(
7679
context, micro_context->GetAlternateProfiler());
7780
break;
81+
case DecodeState::kDcmTypePrune:
82+
dsp = DecodeState::CreateDecodeStatePrune(
83+
context, micro_context->GetAlternateProfiler());
84+
break;
7885
case DecodeState::kDcmTypeCustom:
7986
MicroPrintf("Custom decode type not yet supported");
8087
break;

tensorflow/lite/micro/kernels/decode_state.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "tensorflow/lite/micro/kernels/decode_state.h"
1717

1818
#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
19+
#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
1920
#include "tensorflow/lite/micro/micro_context.h"
2021

2122
namespace tflite {
@@ -33,4 +34,17 @@ DecodeState* DecodeState::CreateDecodeStateLUT(
3334
return dsp;
3435
}
3536

37+
DecodeState* DecodeState::CreateDecodeStatePrune(
38+
const TfLiteContext* context, MicroProfilerInterface* profiler) {
39+
MicroContext* const micro_context = GetMicroContext(context);
40+
void* buffer =
41+
micro_context->AllocatePersistentBuffer(sizeof(DecodeStatePrune));
42+
if (buffer == nullptr) {
43+
return nullptr;
44+
}
45+
DecodeState* dsp = new (buffer) DecodeStatePrune(context, profiler);
46+
47+
return dsp;
48+
}
49+
3650
} // namespace tflite

tensorflow/lite/micro/kernels/decode_state.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class DecodeState {
4343

4444
static DecodeState* CreateDecodeStateLUT(const TfLiteContext* context,
4545
MicroProfilerInterface* profiler);
46+
static DecodeState* CreateDecodeStatePrune(const TfLiteContext* context,
47+
MicroProfilerInterface* profiler);
4648

4749
static uint8_t Type(const TfLiteTensor& ancillary) {
4850
return GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
@@ -66,6 +68,7 @@ class DecodeState {
6668
// Decode Common Metadata constants
6769
public:
6870
static constexpr uint8_t kDcmTypeLUT = 0;
71+
static constexpr uint8_t kDcmTypePrune = 2;
6972
static constexpr uint8_t kDcmTypeCustom = 127;
7073

7174
static constexpr size_t kDcmSizeInBytes = 16;
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
17+
18+
#include <algorithm>
19+
#include <cstddef>
20+
21+
#include "tensorflow/lite/kernels/internal/compatibility.h"
22+
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23+
#include "tensorflow/lite/kernels/kernel_util.h"
24+
#include "tensorflow/lite/micro/micro_context.h"
25+
#include "tensorflow/lite/micro/micro_log.h"
26+
#include "tensorflow/lite/micro/micro_profiler.h"
27+
28+
namespace tflite {
29+
30+
TfLiteStatus DecodeStatePrune::Setup(const TfLiteTensor& input,
31+
const TfLiteTensor& ancillary,
32+
const TfLiteTensor& output) {
33+
const uint8_t* const ancillary_data = GetTensorData<uint8_t>(&ancillary);
34+
if (ancillary_data[kDcmVersionOffset] != 1) {
35+
MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]);
36+
return kTfLiteError;
37+
}
38+
39+
// resolve num_channels_, use_alternate_axis_, and zero points
40+
if (output.quantization.type == kTfLiteAffineQuantization &&
41+
output.quantization.params != nullptr) {
42+
const TfLiteAffineQuantization* quantization =
43+
reinterpret_cast<TfLiteAffineQuantization*>(output.quantization.params);
44+
num_channels_ = quantization->scale->size;
45+
if ((quantization->quantized_dimension == output.dims->size - 1) &&
46+
num_channels_ > 1) {
47+
use_alternate_axis_ = true;
48+
} else if (quantization->quantized_dimension != 0) {
49+
MicroPrintf("unsupported quantization axis %u",
50+
quantization->quantized_dimension);
51+
return kTfLiteError;
52+
}
53+
54+
TFLITE_DCHECK(num_channels_ ==
55+
static_cast<size_t>(quantization->zero_point->size));
56+
bool has_non_zero_zp =
57+
std::any_of(quantization->zero_point->data,
58+
quantization->zero_point->data + num_channels_,
59+
[](int zp) { return zp != 0; });
60+
61+
if (output.type != kTfLiteInt8) {
62+
// make sure all zero points are 0 (zero)
63+
TF_LITE_ENSURE_MSG(const_cast<TfLiteContext*>(context_),
64+
has_non_zero_zp == false,
65+
"All zero-points must be zero");
66+
}
67+
68+
if (num_channels_ > 1 && has_non_zero_zp) {
69+
// copy zero points
70+
MicroContext* micro_context = GetMicroContext(context_);
71+
const size_t bufsize = num_channels_ * sizeof(*zero_points_);
72+
zero_points_ = static_cast<decltype(zero_points_)>(
73+
micro_context->AllocatePersistentBuffer(bufsize));
74+
if (zero_points_ == nullptr) {
75+
MicroPrintf("unable to allocate zero_points_");
76+
return kTfLiteError;
77+
}
78+
std::copy_n(quantization->zero_point->data, num_channels_, zero_points_);
79+
} else {
80+
single_zero_point_ = quantization->zero_point->data[0];
81+
}
82+
}
83+
84+
compressed_indices_ = GetTensorData<uint8_t>(&input);
85+
count_indices_ = NumElements(&output);
86+
elements_per_channel_ =
87+
use_alternate_axis_ ? 1 : count_indices_ / num_channels_;
88+
value_table_ = &ancillary_data[kDcmSizeInBytes];
89+
90+
return kTfLiteOk;
91+
}
92+
93+
TfLiteStatus DecodeStatePrune::Decode(const TfLiteEvalTensor& input,
94+
const TfLiteEvalTensor& ancillary,
95+
const TfLiteEvalTensor& output) {
96+
void* const buffer = const_cast<void*>(micro::GetTensorData<void>(&output));
97+
TFLITE_DCHECK(buffer != nullptr);
98+
99+
switch (output.type) {
100+
case kTfLiteBool:
101+
DecompressToBuffer<int8_t>(buffer);
102+
break;
103+
case kTfLiteFloat32:
104+
DecompressToBuffer<int32_t>(buffer);
105+
break;
106+
case kTfLiteInt8:
107+
if (num_channels_ > 1 && zero_points_ != nullptr) {
108+
DecompressToBufferPerChannelInt8(buffer);
109+
} else {
110+
DecompressToBuffer<int8_t>(buffer);
111+
}
112+
break;
113+
case kTfLiteInt16:
114+
DecompressToBuffer<int16_t>(buffer);
115+
break;
116+
case kTfLiteInt32:
117+
DecompressToBuffer<int32_t>(buffer);
118+
break;
119+
case kTfLiteInt64:
120+
DecompressToBuffer<int64_t>(buffer);
121+
break;
122+
default:
123+
MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type));
124+
return kTfLiteError;
125+
}
126+
127+
return kTfLiteOk;
128+
}
129+
130+
template <typename T>
131+
void DecodeStatePrune::DecompressToBuffer(void* vp) {
132+
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
133+
134+
T* buffer = static_cast<T*>(vp);
135+
const T* value_table = static_cast<const T*>(value_table_);
136+
const size_t max_count = count_indices_;
137+
const uint8_t* const indices = compressed_indices_;
138+
139+
for (size_t index = 0; index < max_count; index++) {
140+
size_t shift = ~index & 0b111;
141+
size_t is_not_zp = (indices[index >> 3] >> shift) & 0b1;
142+
143+
if (is_not_zp) {
144+
*buffer++ = *value_table++;
145+
} else {
146+
*buffer++ = single_zero_point_;
147+
}
148+
}
149+
}
150+
151+
void DecodeStatePrune::DecompressToBufferPerChannelInt8(void* vp) {
152+
TFLITE_DCHECK(zero_points_ != nullptr);
153+
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
154+
155+
int8_t* buffer = static_cast<int8_t*>(vp);
156+
size_t current_offset = 0;
157+
const uint8_t* const indices = compressed_indices_;
158+
const int8_t* value_table = static_cast<const int8_t*>(value_table_);
159+
160+
if (use_alternate_axis_) {
161+
const size_t max_channels = num_channels_;
162+
size_t count = count_indices_;
163+
164+
while (count > 0) {
165+
for (size_t channel = 0; channel < max_channels; channel++) {
166+
const int8_t zp = zero_points_[channel];
167+
size_t shift = ~current_offset & 0b111;
168+
size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1;
169+
170+
if (is_not_zp) {
171+
*buffer++ = *value_table++;
172+
} else {
173+
*buffer++ = zp;
174+
}
175+
current_offset++;
176+
}
177+
count -= max_channels;
178+
}
179+
} else {
180+
const size_t max_count = elements_per_channel_;
181+
182+
for (size_t channel = 0; channel < num_channels_; channel++) {
183+
size_t count = max_count;
184+
const int8_t zp = zero_points_[channel];
185+
186+
while (count-- > 0) {
187+
size_t shift = ~current_offset & 0b111;
188+
size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1;
189+
190+
if (is_not_zp) {
191+
*buffer++ = *value_table++;
192+
} else {
193+
*buffer++ = zp;
194+
}
195+
current_offset++;
196+
}
197+
}
198+
}
199+
}
200+
201+
template void DecodeStatePrune::DecompressToBuffer<int8_t>(void*);
202+
template void DecodeStatePrune::DecompressToBuffer<int16_t>(void*);
203+
template void DecodeStatePrune::DecompressToBuffer<int32_t>(void*);
204+
template void DecodeStatePrune::DecompressToBuffer<int64_t>(void*);
205+
206+
} // namespace tflite

0 commit comments

Comments
 (0)