@@ -79,11 +79,8 @@ GetNormalizationOptionsIfAny(const TensorMetadata& tensor_metadata) {
79
79
tflite::support::StatusOr<std::unique_ptr<ImagePostprocessor>>
80
80
ImagePostprocessor::Create (core::TfLiteEngine* engine,
81
81
const std::initializer_list<int > output_indices,
82
- const std::initializer_list<int > input_indices) {
83
- RETURN_IF_ERROR (Postprocessor::SanityCheck (/* num_expected_tensors = */ 1 ,
84
- engine, output_indices));
85
- auto processor =
86
- absl::WrapUnique (new ImagePostprocessor (engine, output_indices));
82
+ const std::initializer_list<int > input_indices) {
83
+ ASSIGN_OR_RETURN (auto processor, Processor::Create<ImagePostprocessor>(/* num_expected_tensors = */ 1 , engine, output_indices, /* requires_metadata = */ false ));
87
84
88
85
RETURN_IF_ERROR (processor->Init (input_indices));
89
86
return processor;
@@ -100,32 +97,32 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
100
97
tflite::support::TfLiteSupportStatus::kInvalidNumOutputTensorsError );
101
98
}
102
99
103
- if (Tensor ()->type != kTfLiteUInt8 && Tensor ()->type != kTfLiteFloat32 ) {
100
+ if (GetTensor ()->type != kTfLiteUInt8 && GetTensor ()->type != kTfLiteFloat32 ) {
104
101
return tflite::support::CreateStatusWithPayload (
105
102
absl::StatusCode::kInvalidArgument ,
106
103
absl::StrFormat (" Type mismatch for output tensor %s. Requested one "
107
104
" of these types: "
108
105
" kTfLiteUint8/kTfLiteFloat32, got %s." ,
109
- Tensor ()->name , TfLiteTypeGetName (Tensor ()->type )),
106
+ GetTensor ()->name , TfLiteTypeGetName (GetTensor ()->type )),
110
107
tflite::support::TfLiteSupportStatus::kInvalidOutputTensorTypeError );
111
108
}
112
109
113
- if (Tensor ()->dims ->data [0 ] != 1 || Tensor ()->dims ->data [3 ] != 3 ) {
110
+ if (GetTensor ()->dims ->data [0 ] != 1 || GetTensor ()->dims ->data [3 ] != 3 ) {
114
111
return CreateStatusWithPayload (
115
112
absl::StatusCode::kInvalidArgument ,
116
113
absl::StrCat (" The input tensor should have dimensions 1 x height x "
117
114
" width x 3. Got " ,
118
- Tensor ()->dims ->data [0 ], " x " , Tensor ()->dims ->data [1 ],
119
- " x " , Tensor ()->dims ->data [2 ], " x " ,
120
- Tensor ()->dims ->data [3 ], " ." ),
115
+ GetTensor ()->dims ->data [0 ], " x " , GetTensor ()->dims ->data [1 ],
116
+ " x " , GetTensor ()->dims ->data [2 ], " x " ,
117
+ GetTensor ()->dims ->data [3 ], " ." ),
121
118
tflite::support::TfLiteSupportStatus::
122
119
kInvalidInputTensorDimensionsError );
123
120
}
124
121
125
122
// Gather metadata
126
123
auto * output_metadata =
127
124
engine_->metadata_extractor ()->GetOutputTensorMetadata (
128
- output_indices_ .at (0 ));
125
+ tensor_indices_ .at (0 ));
129
126
auto * input_metadata = engine_->metadata_extractor ()->GetInputTensorMetadata (
130
127
input_indices.at (0 ));
131
128
@@ -137,14 +134,14 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
137
134
ASSIGN_OR_RETURN (normalization_options,
138
135
GetNormalizationOptionsIfAny (*processing_metadata));
139
136
140
- if (Tensor ()->type == kTfLiteFloat32 ) {
137
+ if (GetTensor ()->type == kTfLiteFloat32 ) {
141
138
if (!normalization_options.has_value ()) {
142
139
return CreateStatusWithPayload (
143
140
absl::StatusCode::kNotFound ,
144
141
" Output tensor has type kTfLiteFloat32: it requires specifying "
145
142
" NormalizationOptions metadata to preprocess output images." ,
146
143
TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError );
147
- } else if (Tensor ()->bytes / sizeof (float ) %
144
+ } else if (GetTensor ()->bytes / sizeof (float ) %
148
145
normalization_options.value ().num_values !=
149
146
0 ) {
150
147
return CreateStatusWithPayload (
@@ -162,28 +159,28 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
162
159
}
163
160
164
161
absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess () {
165
- has_uint8_outputs_ = Tensor ()->type == kTfLiteUInt8 ;
162
+ has_uint8_outputs_ = GetTensor ()->type == kTfLiteUInt8 ;
166
163
const int kRgbPixelBytes = 3 ;
167
164
168
165
vision::FrameBuffer::Dimension to_buffer_dimension = {
169
- Tensor ()->dims ->data [2 ], Tensor ()->dims ->data [1 ]};
166
+ GetTensor ()->dims ->data [2 ], GetTensor ()->dims ->data [1 ]};
170
167
size_t output_byte_size =
171
168
GetBufferByteSize (to_buffer_dimension, vision::FrameBuffer::Format::kRGB );
172
169
std::vector<uint8> postprocessed_data (output_byte_size / sizeof (uint8), 0 );
173
170
174
171
if (has_uint8_outputs_) { // No normalization required.
175
- if (Tensor ()->bytes != output_byte_size) {
172
+ if (GetTensor ()->bytes != output_byte_size) {
176
173
return tflite::support::CreateStatusWithPayload (
177
174
absl::StatusCode::kInternal ,
178
175
" Size mismatch or unsupported padding bytes between pixel data "
179
176
" and output tensor." );
180
177
}
181
178
const uint8* output_data =
182
- core::AssertAndReturnTypedTensor<uint8>(Tensor () );
179
+ core::AssertAndReturnTypedTensor<uint8>(GetTensor ()). value ( );
183
180
postprocessed_data.insert (postprocessed_data.begin (), &output_data[0 ],
184
181
&output_data[output_byte_size / sizeof (uint8)]);
185
182
} else { // Denormalize to [0, 255] range.
186
- if (Tensor ()->bytes / sizeof (float ) != output_byte_size / sizeof (uint8)) {
183
+ if (GetTensor ()->bytes / sizeof (float ) != output_byte_size / sizeof (uint8)) {
187
184
return tflite::support::CreateStatusWithPayload (
188
185
absl::StatusCode::kInternal ,
189
186
" Size mismatch or unsupported padding bytes between pixel data "
@@ -192,7 +189,7 @@ absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
192
189
193
190
uint8* denormalized_output_data = postprocessed_data.data ();
194
191
const float * output_data =
195
- core::AssertAndReturnTypedTensor<float >(Tensor () );
192
+ core::AssertAndReturnTypedTensor<float >(GetTensor ()). value ( );
196
193
const auto norm_options = GetNormalizationOptions ();
197
194
198
195
if (norm_options.num_values == 1 ) {
@@ -217,7 +214,7 @@ absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
217
214
218
215
vision::FrameBuffer::Plane postprocessed_plane = {
219
216
/* buffer=*/ postprocessed_data.data (),
220
- /* stride=*/ {Tensor ()->dims ->data [2 ] * kRgbPixelBytes , kRgbPixelBytes }};
217
+ /* stride=*/ {GetTensor ()->dims ->data [2 ] * kRgbPixelBytes , kRgbPixelBytes }};
221
218
auto postprocessed_frame_buffer =
222
219
vision::FrameBuffer::Create ({postprocessed_plane}, to_buffer_dimension,
223
220
vision::FrameBuffer::Format::kRGB ,
0 commit comments