|
| 1 | +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +#include "tensorflow/lite/micro/kernels/decode_state_prune.h" |
| 17 | + |
| 18 | +#include <algorithm> |
| 19 | +#include <cstddef> |
| 20 | + |
| 21 | +#include "tensorflow/lite/kernels/internal/compatibility.h" |
| 22 | +#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
| 23 | +#include "tensorflow/lite/kernels/kernel_util.h" |
| 24 | +#include "tensorflow/lite/micro/micro_context.h" |
| 25 | +#include "tensorflow/lite/micro/micro_log.h" |
| 26 | +#include "tensorflow/lite/micro/micro_profiler.h" |
| 27 | + |
| 28 | +namespace tflite { |
| 29 | + |
// One-time preparation for decoding a pruned (sparse) tensor.
//
// Validates the ancillary metadata version, resolves the quantization layout
// of the output tensor (channel count, quantized axis, zero-points), and
// caches the raw pointers Decode() will need: the compressed bitmap indices,
// the non-zero value table, and the element counts.
//
// Returns kTfLiteOk on success, kTfLiteError on an unsupported metadata
// version, an unsupported quantization axis, or an allocation failure.
TfLiteStatus DecodeStatePrune::Setup(const TfLiteTensor& input,
                                     const TfLiteTensor& ancillary,
                                     const TfLiteTensor& output) {
  const uint8_t* const ancillary_data = GetTensorData<uint8_t>(&ancillary);
  // Only version 1 of the metadata header is supported.
  if (ancillary_data[kDcmVersionOffset] != 1) {
    MicroPrintf("unsupported version %u", ancillary_data[kDcmVersionOffset]);
    return kTfLiteError;
  }

  // resolve num_channels_, use_alternate_axis_, and zero points
  if (output.quantization.type == kTfLiteAffineQuantization &&
      output.quantization.params != nullptr) {
    const TfLiteAffineQuantization* quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(output.quantization.params);
    num_channels_ = quantization->scale->size;
    // Quantization along the last dimension ("alternate axis") means channels
    // vary fastest in memory; only axis 0 or the last axis is supported.
    if ((quantization->quantized_dimension == output.dims->size - 1) &&
        num_channels_ > 1) {
      use_alternate_axis_ = true;
    } else if (quantization->quantized_dimension != 0) {
      MicroPrintf("unsupported quantization axis %u",
                  quantization->quantized_dimension);
      return kTfLiteError;
    }

    // Scales and zero-points must be paired one-to-one per channel.
    TFLITE_DCHECK(num_channels_ ==
                  static_cast<size_t>(quantization->zero_point->size));
    bool has_non_zero_zp =
        std::any_of(quantization->zero_point->data,
                    quantization->zero_point->data + num_channels_,
                    [](int zp) { return zp != 0; });

    if (output.type != kTfLiteInt8) {
      // make sure all zero points are 0 (zero)
      TF_LITE_ENSURE_MSG(const_cast<TfLiteContext*>(context_),
                         has_non_zero_zp == false,
                         "All zero-points must be zero");
    }

    if (num_channels_ > 1 && has_non_zero_zp) {
      // copy zero points
      // Per-channel zero-points are needed at Decode() time, so persist a
      // copy; the quantization params may not outlive Prepare.
      MicroContext* micro_context = GetMicroContext(context_);
      const size_t bufsize = num_channels_ * sizeof(*zero_points_);
      zero_points_ = static_cast<decltype(zero_points_)>(
          micro_context->AllocatePersistentBuffer(bufsize));
      if (zero_points_ == nullptr) {
        MicroPrintf("unable to allocate zero_points_");
        return kTfLiteError;
      }
      std::copy_n(quantization->zero_point->data, num_channels_, zero_points_);
    } else {
      // Single shared zero-point (or all-zero per-channel zero-points).
      single_zero_point_ = quantization->zero_point->data[0];
    }
  }

  // Cache decode inputs: the packed one-bit-per-element sparsity bitmap, the
  // total element count, and the table of non-zero values that follows the
  // metadata header in the ancillary tensor.
  compressed_indices_ = GetTensorData<uint8_t>(&input);
  count_indices_ = NumElements(&output);
  elements_per_channel_ =
      use_alternate_axis_ ? 1 : count_indices_ / num_channels_;
  value_table_ = &ancillary_data[kDcmSizeInBytes];

  return kTfLiteOk;
}
| 92 | + |
| 93 | +TfLiteStatus DecodeStatePrune::Decode(const TfLiteEvalTensor& input, |
| 94 | + const TfLiteEvalTensor& ancillary, |
| 95 | + const TfLiteEvalTensor& output) { |
| 96 | + void* const buffer = const_cast<void*>(micro::GetTensorData<void>(&output)); |
| 97 | + TFLITE_DCHECK(buffer != nullptr); |
| 98 | + |
| 99 | + switch (output.type) { |
| 100 | + case kTfLiteBool: |
| 101 | + DecompressToBuffer<int8_t>(buffer); |
| 102 | + break; |
| 103 | + case kTfLiteFloat32: |
| 104 | + DecompressToBuffer<int32_t>(buffer); |
| 105 | + break; |
| 106 | + case kTfLiteInt8: |
| 107 | + if (num_channels_ > 1 && zero_points_ != nullptr) { |
| 108 | + DecompressToBufferPerChannelInt8(buffer); |
| 109 | + } else { |
| 110 | + DecompressToBuffer<int8_t>(buffer); |
| 111 | + } |
| 112 | + break; |
| 113 | + case kTfLiteInt16: |
| 114 | + DecompressToBuffer<int16_t>(buffer); |
| 115 | + break; |
| 116 | + case kTfLiteInt32: |
| 117 | + DecompressToBuffer<int32_t>(buffer); |
| 118 | + break; |
| 119 | + case kTfLiteInt64: |
| 120 | + DecompressToBuffer<int64_t>(buffer); |
| 121 | + break; |
| 122 | + default: |
| 123 | + MicroPrintf("unsupported tensor type %s", TfLiteTypeGetName(output.type)); |
| 124 | + return kTfLiteError; |
| 125 | + } |
| 126 | + |
| 127 | + return kTfLiteOk; |
| 128 | +} |
| 129 | + |
| 130 | +template <typename T> |
| 131 | +void DecodeStatePrune::DecompressToBuffer(void* vp) { |
| 132 | + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); |
| 133 | + |
| 134 | + T* buffer = static_cast<T*>(vp); |
| 135 | + const T* value_table = static_cast<const T*>(value_table_); |
| 136 | + const size_t max_count = count_indices_; |
| 137 | + const uint8_t* const indices = compressed_indices_; |
| 138 | + |
| 139 | + for (size_t index = 0; index < max_count; index++) { |
| 140 | + size_t shift = ~index & 0b111; |
| 141 | + size_t is_not_zp = (indices[index >> 3] >> shift) & 0b1; |
| 142 | + |
| 143 | + if (is_not_zp) { |
| 144 | + *buffer++ = *value_table++; |
| 145 | + } else { |
| 146 | + *buffer++ = single_zero_point_; |
| 147 | + } |
| 148 | + } |
| 149 | +} |
| 150 | + |
| 151 | +void DecodeStatePrune::DecompressToBufferPerChannelInt8(void* vp) { |
| 152 | + TFLITE_DCHECK(zero_points_ != nullptr); |
| 153 | + ScopedMicroProfiler scoped_profiler(__func__, micro_profiler_); |
| 154 | + |
| 155 | + int8_t* buffer = static_cast<int8_t*>(vp); |
| 156 | + size_t current_offset = 0; |
| 157 | + const uint8_t* const indices = compressed_indices_; |
| 158 | + const int8_t* value_table = static_cast<const int8_t*>(value_table_); |
| 159 | + |
| 160 | + if (use_alternate_axis_) { |
| 161 | + const size_t max_channels = num_channels_; |
| 162 | + size_t count = count_indices_; |
| 163 | + |
| 164 | + while (count > 0) { |
| 165 | + for (size_t channel = 0; channel < max_channels; channel++) { |
| 166 | + const int8_t zp = zero_points_[channel]; |
| 167 | + size_t shift = ~current_offset & 0b111; |
| 168 | + size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1; |
| 169 | + |
| 170 | + if (is_not_zp) { |
| 171 | + *buffer++ = *value_table++; |
| 172 | + } else { |
| 173 | + *buffer++ = zp; |
| 174 | + } |
| 175 | + current_offset++; |
| 176 | + } |
| 177 | + count -= max_channels; |
| 178 | + } |
| 179 | + } else { |
| 180 | + const size_t max_count = elements_per_channel_; |
| 181 | + |
| 182 | + for (size_t channel = 0; channel < num_channels_; channel++) { |
| 183 | + size_t count = max_count; |
| 184 | + const int8_t zp = zero_points_[channel]; |
| 185 | + |
| 186 | + while (count-- > 0) { |
| 187 | + size_t shift = ~current_offset & 0b111; |
| 188 | + size_t is_not_zp = (indices[current_offset >> 3] >> shift) & 0b1; |
| 189 | + |
| 190 | + if (is_not_zp) { |
| 191 | + *buffer++ = *value_table++; |
| 192 | + } else { |
| 193 | + *buffer++ = zp; |
| 194 | + } |
| 195 | + current_offset++; |
| 196 | + } |
| 197 | + } |
| 198 | + } |
| 199 | +} |
| 200 | + |
// Explicit instantiations for every element width Decode() dispatches to
// (bool shares int8_t, float32 shares int32_t).
template void DecodeStatePrune::DecompressToBuffer<int8_t>(void*);
template void DecodeStatePrune::DecompressToBuffer<int16_t>(void*);
template void DecodeStatePrune::DecompressToBuffer<int32_t>(void*);
template void DecodeStatePrune::DecompressToBuffer<int64_t>(void*);
| 205 | + |
| 206 | +} // namespace tflite |
0 commit comments