Skip to content

Commit 08c536b

Browse files
authored
Use the latest API
1 parent 6706c1e commit 08c536b

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

tensorflow_lite_support/cc/task/processor/image_postprocessor.cc

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,8 @@ GetNormalizationOptionsIfAny(const TensorMetadata& tensor_metadata) {
7979
tflite::support::StatusOr<std::unique_ptr<ImagePostprocessor>>
8080
ImagePostprocessor::Create(core::TfLiteEngine* engine,
8181
const std::initializer_list<int> output_indices,
82-
const std::initializer_list<int> input_indices) {
83-
RETURN_IF_ERROR(Postprocessor::SanityCheck(/* num_expected_tensors = */ 1,
84-
engine, output_indices));
85-
auto processor =
86-
absl::WrapUnique(new ImagePostprocessor(engine, output_indices));
82+
const std::initializer_list<int> input_indices) {
83+
ASSIGN_OR_RETURN(auto processor, Processor::Create<ImagePostprocessor>(/* num_expected_tensors = */ 1, engine, output_indices, /* requires_metadata = */ false));
8784

8885
RETURN_IF_ERROR(processor->Init(input_indices));
8986
return processor;
@@ -100,32 +97,32 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
10097
tflite::support::TfLiteSupportStatus::kInvalidNumOutputTensorsError);
10198
}
10299

103-
if (Tensor()->type != kTfLiteUInt8 && Tensor()->type != kTfLiteFloat32) {
100+
if (GetTensor()->type != kTfLiteUInt8 && GetTensor()->type != kTfLiteFloat32) {
104101
return tflite::support::CreateStatusWithPayload(
105102
absl::StatusCode::kInvalidArgument,
106103
absl::StrFormat("Type mismatch for output tensor %s. Requested one "
107104
"of these types: "
108105
"kTfLiteUint8/kTfLiteFloat32, got %s.",
109-
Tensor()->name, TfLiteTypeGetName(Tensor()->type)),
106+
GetTensor()->name, TfLiteTypeGetName(GetTensor()->type)),
110107
tflite::support::TfLiteSupportStatus::kInvalidOutputTensorTypeError);
111108
}
112109

113-
if (Tensor()->dims->data[0] != 1 || Tensor()->dims->data[3] != 3) {
110+
if (GetTensor()->dims->data[0] != 1 || GetTensor()->dims->data[3] != 3) {
114111
return CreateStatusWithPayload(
115112
absl::StatusCode::kInvalidArgument,
116113
absl::StrCat("The input tensor should have dimensions 1 x height x "
117114
"width x 3. Got ",
118-
Tensor()->dims->data[0], " x ", Tensor()->dims->data[1],
119-
" x ", Tensor()->dims->data[2], " x ",
120-
Tensor()->dims->data[3], "."),
115+
GetTensor()->dims->data[0], " x ", GetTensor()->dims->data[1],
116+
" x ", GetTensor()->dims->data[2], " x ",
117+
GetTensor()->dims->data[3], "."),
121118
tflite::support::TfLiteSupportStatus::
122119
kInvalidInputTensorDimensionsError);
123120
}
124121

125122
// Gather metadata
126123
auto* output_metadata =
127124
engine_->metadata_extractor()->GetOutputTensorMetadata(
128-
output_indices_.at(0));
125+
tensor_indices_.at(0));
129126
auto* input_metadata = engine_->metadata_extractor()->GetInputTensorMetadata(
130127
input_indices.at(0));
131128

@@ -137,14 +134,14 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
137134
ASSIGN_OR_RETURN(normalization_options,
138135
GetNormalizationOptionsIfAny(*processing_metadata));
139136

140-
if (Tensor()->type == kTfLiteFloat32) {
137+
if (GetTensor()->type == kTfLiteFloat32) {
141138
if (!normalization_options.has_value()) {
142139
return CreateStatusWithPayload(
143140
absl::StatusCode::kNotFound,
144141
"Output tensor has type kTfLiteFloat32: it requires specifying "
145142
"NormalizationOptions metadata to preprocess output images.",
146143
TfLiteSupportStatus::kMetadataMissingNormalizationOptionsError);
147-
} else if (Tensor()->bytes / sizeof(float) %
144+
} else if (GetTensor()->bytes / sizeof(float) %
148145
normalization_options.value().num_values !=
149146
0) {
150147
return CreateStatusWithPayload(
@@ -162,28 +159,28 @@ absl::Status ImagePostprocessor::Init(const std::vector<int>& input_indices) {
162159
}
163160

164161
absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
165-
has_uint8_outputs_ = Tensor()->type == kTfLiteUInt8;
162+
has_uint8_outputs_ = GetTensor()->type == kTfLiteUInt8;
166163
const int kRgbPixelBytes = 3;
167164

168165
vision::FrameBuffer::Dimension to_buffer_dimension = {
169-
Tensor()->dims->data[2], Tensor()->dims->data[1]};
166+
GetTensor()->dims->data[2], GetTensor()->dims->data[1]};
170167
size_t output_byte_size =
171168
GetBufferByteSize(to_buffer_dimension, vision::FrameBuffer::Format::kRGB);
172169
std::vector<uint8> postprocessed_data(output_byte_size / sizeof(uint8), 0);
173170

174171
if (has_uint8_outputs_) { // No normalization required.
175-
if (Tensor()->bytes != output_byte_size) {
172+
if (GetTensor()->bytes != output_byte_size) {
176173
return tflite::support::CreateStatusWithPayload(
177174
absl::StatusCode::kInternal,
178175
"Size mismatch or unsupported padding bytes between pixel data "
179176
"and output tensor.");
180177
}
181178
const uint8* output_data =
182-
core::AssertAndReturnTypedTensor<uint8>(Tensor());
179+
core::AssertAndReturnTypedTensor<uint8>(GetTensor()).value();
183180
postprocessed_data.insert(postprocessed_data.begin(), &output_data[0],
184181
&output_data[output_byte_size / sizeof(uint8)]);
185182
} else { // Denormalize to [0, 255] range.
186-
if (Tensor()->bytes / sizeof(float) != output_byte_size / sizeof(uint8)) {
183+
if (GetTensor()->bytes / sizeof(float) != output_byte_size / sizeof(uint8)) {
187184
return tflite::support::CreateStatusWithPayload(
188185
absl::StatusCode::kInternal,
189186
"Size mismatch or unsupported padding bytes between pixel data "
@@ -192,7 +189,7 @@ absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
192189

193190
uint8* denormalized_output_data = postprocessed_data.data();
194191
const float* output_data =
195-
core::AssertAndReturnTypedTensor<float>(Tensor());
192+
core::AssertAndReturnTypedTensor<float>(GetTensor()).value();
196193
const auto norm_options = GetNormalizationOptions();
197194

198195
if (norm_options.num_values == 1) {
@@ -217,7 +214,7 @@ absl::StatusOr<vision::FrameBuffer> ImagePostprocessor::Postprocess() {
217214

218215
vision::FrameBuffer::Plane postprocessed_plane = {
219216
/*buffer=*/postprocessed_data.data(),
220-
/*stride=*/{Tensor()->dims->data[2] * kRgbPixelBytes, kRgbPixelBytes}};
217+
/*stride=*/{GetTensor()->dims->data[2] * kRgbPixelBytes, kRgbPixelBytes}};
221218
auto postprocessed_frame_buffer =
222219
vision::FrameBuffer::Create({postprocessed_plane}, to_buffer_dimension,
223220
vision::FrameBuffer::Format::kRGB,

0 commit comments

Comments
 (0)