Skip to content

Commit a94423c

Browse files
authored
Support for DECODE operator (#3188)
* Support for DECODE operator @tensorflow/micro Add initial support for DECODE operator. Add reference implementation. Add LUT decompression support. Update op resolvers. Update Makefiles and Bazel BUILD files. Add kernel unit test. bug=fixes #3131 * update copyright * Don't use constructors with global objects (bluepill will not call them). Cleanup unit test. * Support for DECODE operator @tensorflow/micro Additional support for DECODE operator. Add Xtensa optimizations for LUT decompression. Move all Xtensa kernel source references to the Xtensa target makefile. bug=fixes #3150 * Updates to Xtensa makefiles @tensorflow/micro Reorganize Xtensa makefiles such that all references to optimized kernel sources are moved to the Xtensa target makefile. Move hifimini kernel sources to the parent directory, and rename them so they do not interfere with the target overlay mechanism of the root makefile. bug=fixes #3153 * Fix incorrect include path. Fix code style errors. * fix copyright * update generic benchmark op resolver size * Support for DECODE operator @tensorflow/micro Add reference implementation of pruning to DECODE operator. Makefile and Bazel BUILD file changes. Additional unit tests. bug=fixes #3161 * xtensa int8 single channel working * xtensa per-channel int8 normal axis working * WIP * working xtensa optimizations * Add negative unit test * Support for DECODE operator @tensorflow/micro Add optimized xtensa implementation of pruning to DECODE operator. Makefile changes. Additional unit tests. bug=fixes #3171 * all tests pass * Support for DECODE operator @tensorflow/micro Add reference implementation of Huffman decompression to DECODE operator. Makefile and Bazel BUILD file changes. Additional unit tests. bug=fixes #3187 * Add ScopedMicroProfiler * unfinished merge changes * Split out huffman unit test. Remove xtensa optimizations. * cleanup * Post-review updates.
1 parent aedf929 commit a94423c

File tree

9 files changed

+739
-0
lines changed

9 files changed

+739
-0
lines changed

tensorflow/lite/micro/kernels/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ tflm_kernel_cc_library(
252252
"cumsum.cc",
253253
"decode.cc",
254254
"decode_state.cc",
255+
"decode_state_huffman.cc",
255256
"decode_state_lut.cc",
256257
"decode_state_prune.cc",
257258
"depth_to_space.cc",
@@ -346,6 +347,7 @@ tflm_kernel_cc_library(
346347
"circular_buffer.h",
347348
"conv.h",
348349
"decode_state.h",
350+
"decode_state_huffman.h",
349351
"decode_state_lut.h",
350352
"decode_state_prune.h",
351353
"depthwise_conv.h",
@@ -664,6 +666,22 @@ tflm_cc_test(
664666
],
665667
)
666668

669+
# Unit-test target for the Huffman decode state, built from
# decode_state_huffman_test.cc. It links the shared decode test helpers and
# the kernel runner in addition to the common micro test dependencies listed
# in deps below.
tflm_cc_test(
670+
name = "decode_state_huffman_test",
671+
srcs = [
672+
"decode_state_huffman_test.cc",
673+
],
674+
deps = [
675+
":decode_test_helpers",
676+
":kernel_runner",
677+
"//tensorflow/lite/c:common",
678+
"//tensorflow/lite/micro:debug_log",
679+
"//tensorflow/lite/micro:op_resolvers",
680+
"//tensorflow/lite/micro:test_helpers",
681+
"//tensorflow/lite/micro/testing:micro_test",
682+
],
683+
)
684+
667685
tflm_cc_test(
668686
name = "decode_state_lut_test",
669687
srcs = [

tensorflow/lite/micro/kernels/Makefile.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ $(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/ceil_test.cc \
123123
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/comparisons_test.cc \
124124
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/concatenation_test.cc \
125125
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/cumsum_test.cc \
126+
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_huffman_test.cc \
126127
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_lut_test.cc \
127128
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_state_prune_test.cc \
128129
$(TENSORFLOW_ROOT)tensorflow/lite/micro/kernels/decode_test.cc \

tensorflow/lite/micro/kernels/decode.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
8282
dsp = DecodeState::CreateDecodeStatePrune(
8383
context, micro_context->GetAlternateProfiler());
8484
break;
85+
case DecodeState::kDcmTypeHuffman:
86+
dsp = DecodeState::CreateDecodeStateHuffman(
87+
context, micro_context->GetAlternateProfiler());
88+
break;
8589
case DecodeState::kDcmTypeCustom:
8690
MicroPrintf("Custom decode type not yet supported");
8791
break;

tensorflow/lite/micro/kernels/decode_state.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include "tensorflow/lite/micro/kernels/decode_state.h"
1717

18+
#include "tensorflow/lite/micro/kernels/decode_state_huffman.h"
1819
#include "tensorflow/lite/micro/kernels/decode_state_lut.h"
1920
#include "tensorflow/lite/micro/kernels/decode_state_prune.h"
2021
#include "tensorflow/lite/micro/micro_context.h"
@@ -47,4 +48,17 @@ DecodeState* DecodeState::CreateDecodeStatePrune(
4748
return dsp;
4849
}
4950

51+
// Factory for the Huffman decode state.
//
// Requests sizeof(DecodeStateHuffman) bytes from the MicroContext's
// persistent-buffer allocator and constructs the object in that storage with
// placement new, forwarding `context` and `profiler` to the constructor.
//
// Returns the new object as a DecodeState*, or nullptr when the allocator
// returns no buffer. No matching delete is performed here; the storage comes
// from AllocatePersistentBuffer() (presumably arena-owned -- confirm against
// MicroContext).
DecodeState* DecodeState::CreateDecodeStateHuffman(
52+
const TfLiteContext* context, MicroProfilerInterface* profiler) {
53+
MicroContext* const micro_context = GetMicroContext(context);
54+
void* buffer =
55+
micro_context->AllocatePersistentBuffer(sizeof(DecodeStateHuffman));
56+
// Allocation failure is reported to the caller as a null state pointer.
if (buffer == nullptr) {
57+
return nullptr;
58+
}
59+
// Construct in place inside the buffer obtained above.
DecodeState* dsp = new (buffer) DecodeStateHuffman(context, profiler);
60+
61+
return dsp;
62+
}
63+
5064
} // namespace tflite

tensorflow/lite/micro/kernels/decode_state.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ class DecodeState {
4545
MicroProfilerInterface* profiler);
4646
static DecodeState* CreateDecodeStatePrune(const TfLiteContext* context,
4747
MicroProfilerInterface* profiler);
48+
static DecodeState* CreateDecodeStateHuffman(
49+
const TfLiteContext* context, MicroProfilerInterface* profiler);
4850

4951
static uint8_t Type(const TfLiteTensor& ancillary) {
5052
return GetTensorData<uint8_t>(&ancillary)[kDcmDecodeTypeOffset];
@@ -68,6 +70,7 @@ class DecodeState {
6870
// Decode Common Metadata constants
6971
public:
7072
static constexpr uint8_t kDcmTypeLUT = 0;
73+
static constexpr uint8_t kDcmTypeHuffman = 1;
7174
static constexpr uint8_t kDcmTypePrune = 2;
7275
static constexpr uint8_t kDcmTypeCustom = 127;
7376

tensorflow/lite/micro/kernels/decode_state_huffman.cc

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/lite/micro/kernels/decode_state_huffman.h"
17+
18+
#include "tensorflow/lite/kernels/internal/compatibility.h"
19+
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
20+
#include "tensorflow/lite/kernels/kernel_util.h"
21+
#include "tensorflow/lite/micro/micro_log.h"
22+
#include "tensorflow/lite/micro/micro_profiler.h"
23+
24+
namespace tflite {
25+
26+
// Validates the ancillary metadata and caches the parameters that the
// decompression routines read later.
//
// input:     tensor whose data is recorded as the compressed codeword stream
//            (viewed as uint32_t words).
// ancillary: tensor whose bytes hold the metadata header followed, at byte
//            offset kDcmSizeInBytes, by the Huffman tables.
// output:    tensor whose element count becomes the number of symbols to
//            decode, and whose type is checked for the 16-bit table case.
//
// Returns kTfLiteError when the metadata version is not 1; otherwise
// kTfLiteOk unless the type check below fails.
TfLiteStatus DecodeStateHuffman::Setup(const TfLiteTensor& input,
27+
const TfLiteTensor& ancillary,
28+
const TfLiteTensor& output) {
29+
const uint8_t* const ancillary_data = GetTensorData<uint8_t>(&ancillary);
30+
// Only metadata version 1 is accepted; any other value is logged and rejected.
if (ancillary_data[kDcmVersionOffset] != 1) {
31+
MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]);
32+
return kTfLiteError;
33+
}
34+
35+
compressed_codewords_ = GetTensorData<uint32_t>(&input);
36+
// One decoded symbol is produced per output element.
count_codewords_ = NumElements(&output);
37+
// The tables start immediately after the kDcmSizeInBytes-byte header.
huffman_tables_ = &ancillary_data[kDcmSizeInBytes];
38+
// The byte at kDcmTableSizeOffset packs two fields: a flag selecting 32-bit
// table entries, and the initial table size field.
use_32bit_table_ =
39+
(ancillary_data[kDcmTableSizeOffset] & kDcmTableSize32BitsMask) != 0;
40+
initial_table_size_ =
41+
(ancillary_data[kDcmTableSizeOffset] & kDcmTableSizeInitialMask) >>
42+
kDcmTableSizeInitialShift;
43+
44+
// The 16-bit table decoder only writes int8_t output, so any other output
// type is rejected up front. The const_cast suggests context_ is stored as a
// pointer-to-const while the macro wants a mutable TfLiteContext* -- confirm
// against the class declaration.
if (!use_32bit_table_) {
45+
TF_LITE_ENSURE_TYPES_EQ(const_cast<TfLiteContext*>(context_), output.type,
46+
kTfLiteInt8);
47+
}
48+
49+
return kTfLiteOk;
50+
}
51+
52+
// Decompresses into the output tensor's data buffer, choosing the decoder
// from the output element type and the table width recorded by Setup().
//
// `input` and `ancillary` are not referenced in this function; the decoders
// work from member state assigned in Setup().
//
// Returns kTfLiteOk after decoding, or kTfLiteError (after logging the type
// name) when the output type is neither kTfLiteInt8 nor kTfLiteInt16.
TfLiteStatus DecodeStateHuffman::Decode(const TfLiteEvalTensor& input,
53+
const TfLiteEvalTensor& ancillary,
54+
const TfLiteEvalTensor& output) {
55+
// `output` is passed by const reference, but its data buffer is the write
// target, hence the const_cast.
void* const buffer = const_cast<void*>(micro::GetTensorData<void>(&output));
56+
TFLITE_DCHECK(buffer != nullptr);
57+
58+
switch (output.type) {
59+
// int8 output may use either table width.
case kTfLiteInt8:
60+
if (use_32bit_table_) {
61+
DecompressToBufferWith32BitTable(static_cast<int8_t*>(buffer));
62+
} else {
63+
DecompressToBufferWith16BitTable(static_cast<int8_t*>(buffer));
64+
}
65+
break;
66+
// int16 output always goes through the 32-bit table decoder; Setup()
// rejects non-int8 output when the 16-bit table is selected.
case kTfLiteInt16:
67+
DecompressToBufferWith32BitTable(static_cast<int16_t*>(buffer));
68+
break;
69+
default:
70+
MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type));
71+
return kTfLiteError;
72+
}
73+
74+
return kTfLiteOk;
75+
}
76+
77+
// Decodes count_codewords_ symbols into `buffer` using Huffman tables made of
// 16-bit entries.
//
// For each symbol, the first table lookup consumes (initial_table_size_ + 1)
// bits. While the fetched entry lacks the symbol-found flag, the entry's
// count field gives the number of further bits to read and its value field
// gives an offset added to the current table position. Once a symbol entry is
// reached it is written to `buffer` and any surplus bits from the last read
// are returned to the stream.
//
// InitNextBits / GetNextBits / PutBackBits are defined outside this file;
// presumably they initialize, read from, and rewind the compressed codeword
// bit stream tracked by head_offset / head_hold / head_next -- confirm
// against the header.
//
// `buffer` must have room for count_codewords_ int8_t values; no bounds are
// checked here, either on the output or on the table index.
void DecodeStateHuffman::DecompressToBufferWith16BitTable(int8_t* buffer) {
78+
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
79+
80+
size_t remaining = count_codewords_;
81+
// The stored field is one less than the bit count used for the first lookup.
const size_t initial_table_size = initial_table_size_ + 1;
82+
const uint16_t* huffman_tables =
83+
static_cast<const uint16_t*>(huffman_tables_);
84+
uint32_t head_offset = 0; // codewords bitstring state
85+
uint32_t head_hold = 0; // codewords bitstring state
86+
const uint32_t* head_next = nullptr; // codewords bitstring state
87+
uint16_t table_value = 0;
88+
89+
InitNextBits(head_offset, head_hold, head_next);
90+
91+
// One iteration per output symbol.
while (remaining--) {
92+
size_t last_used_bits = initial_table_size;
93+
uint32_t current_index =
94+
GetNextBits(last_used_bits, head_offset, head_hold, head_next);
95+
size_t table_offset = current_index;
96+
table_value = huffman_tables[table_offset];
97+
98+
// Follow non-symbol entries until one carries the symbol-found flag.
while (!(table_value & kTable16BitSymbolFoundMask)) {
99+
// The count field stores (bits to read) - 1 for the next lookup.
last_used_bits =
100+
((table_value & kTable16BitCountMask) >> kTable16BitCountShift) + 1;
101+
current_index =
102+
GetNextBits(last_used_bits, head_offset, head_hold, head_next);
103+
const size_t next_table_offset = table_value & kTable16BitValueMask;
104+
// The offset is relative to the current position, not the table start.
table_offset += next_table_offset + current_index;
105+
table_value = huffman_tables[table_offset];
106+
}
107+
108+
// Implicit narrowing from the 16-bit entry to int8_t; the symbol is
// presumably held in the entry's low bits -- confirm against the table
// format.
*buffer++ = table_value;
109+
110+
// On a symbol entry the count field is compared with the number of bits the
// last read consumed; when the read took more, the difference is handed
// back to the stream via PutBackBits.
const size_t symbol_residual_bits =
111+
(table_value & kTable16BitCountMask) >> kTable16BitCountShift;
112+
if (last_used_bits > symbol_residual_bits) {
113+
PutBackBits(last_used_bits - symbol_residual_bits, head_offset, head_hold,
114+
head_next);
115+
}
116+
}
117+
}
118+
119+
// Decodes count_codewords_ symbols into `buffer` using Huffman tables made of
// 32-bit entries. T is the output element type; this file instantiates it for
// int8_t and int16_t.
//
// The walk mirrors the 16-bit table decoder but uses the kTable32Bit* masks
// and shifts. The first lookup consumes (initial_table_size_ + 1) bits. Each
// non-symbol entry supplies a bit count (count field + 1) and a relative
// offset (value field) for the next lookup. A symbol entry is written to
// `buffer` and surplus bits from the last read are put back.
//
// InitNextBits / GetNextBits / PutBackBits are defined outside this file;
// presumably they initialize, read from, and rewind the compressed codeword
// bit stream tracked by head_offset / head_hold / head_next -- confirm
// against the header.
//
// `buffer` must have room for count_codewords_ elements of T; no bounds are
// checked here, either on the output or on the table index.
template <typename T>
120+
void DecodeStateHuffman::DecompressToBufferWith32BitTable(T* buffer) {
121+
ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_);
122+
123+
size_t remaining = count_codewords_;
124+
// The stored field is one less than the bit count used for the first lookup.
const size_t initial_table_size = initial_table_size_ + 1;
125+
const uint32_t* huffman_tables =
126+
static_cast<const uint32_t*>(huffman_tables_);
127+
uint32_t head_offset = 0; // codewords bitstring state
128+
uint32_t head_hold = 0; // codewords bitstring state
129+
const uint32_t* head_next = nullptr; // codewords bitstring state
130+
uint32_t table_value = 0;
131+
132+
InitNextBits(head_offset, head_hold, head_next);
133+
134+
// One iteration per output symbol.
while (remaining--) {
135+
size_t last_used_bits = initial_table_size;
136+
uint32_t current_index =
137+
GetNextBits(last_used_bits, head_offset, head_hold, head_next);
138+
size_t table_offset = current_index;
139+
table_value = huffman_tables[table_offset];
140+
141+
// Follow non-symbol entries until one carries the symbol-found flag.
while (!(table_value & kTable32BitSymbolFoundMask)) {
142+
// The count field stores (bits to read) - 1 for the next lookup.
last_used_bits =
143+
((table_value & kTable32BitCountMask) >> kTable32BitCountShift) + 1;
144+
current_index =
145+
GetNextBits(last_used_bits, head_offset, head_hold, head_next);
146+
const size_t next_table_offset = table_value & kTable32BitValueMask;
147+
// The offset is relative to the current position, not the table start.
table_offset += next_table_offset + current_index;
148+
table_value = huffman_tables[table_offset];
149+
}
150+
151+
// Implicit narrowing from the 32-bit entry to T; the symbol is presumably
// held in the entry's low bits -- confirm against the table format.
*buffer++ = table_value;
152+
153+
// On a symbol entry the count field is compared with the number of bits the
// last read consumed; when the read took more, the difference is handed
// back to the stream via PutBackBits.
const size_t symbol_residual_bits =
154+
(table_value & kTable32BitCountMask) >> kTable32BitCountShift;
155+
if (last_used_bits > symbol_residual_bits) {
156+
PutBackBits(last_used_bits - symbol_residual_bits, head_offset, head_hold,
157+
head_next);
158+
}
159+
}
160+
}
161+
162+
// Explicit instantiations of the 32-bit table decoder for int8_t and int16_t,
// the two element types Decode() passes to it. The template body is defined
// in this source file, so these make both specializations available to the
// linker.
template void DecodeStateHuffman::DecompressToBufferWith32BitTable<int8_t>(
163+
int8_t*);
164+
template void DecodeStateHuffman::DecompressToBufferWith32BitTable<int16_t>(
165+
int16_t*);
166+
167+
} // namespace tflite

0 commit comments

Comments
 (0)