@@ -27,16 +27,18 @@ See the License for the specific language governing permissions and
2727limitations under the License.
2828==============================================================================*/
2929
30+ #include < cstdint>
3031#include < iterator>
32+ #include < limits>
3133#include < vector>
3234
33- #include " tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h"
34- #include " tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h"
3535#include " tensorflow/core/framework/op.h"
3636#include " tensorflow/core/framework/op_kernel.h"
3737#include " tensorflow/core/framework/shape_inference.h"
3838#include " tensorflow/core/framework/tensor.h"
3939#include " tensorflow/core/platform/errors.h"
40+ #include " tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h"
41+ #include " tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h"
4042
4143namespace tensorflow {
4244namespace text {
@@ -50,7 +52,7 @@ class TFSentencepieceOp : public tensorflow::OpKernel {
5052 const auto & input_values_tensor = ctx->input (kInputIndex );
5153 const auto input_values_flat =
5254 input_values_tensor.flat <tensorflow::tstring>();
53- const int num_of_input_values = input_values_flat.size ();
55+ const int64_t num_of_input_values = input_values_flat.size ();
5456
5557 const auto & add_bos_tensor = ctx->input (kAddBOSInput );
5658 const bool add_bos = add_bos_tensor.scalar <bool >()();
@@ -74,20 +76,26 @@ class TFSentencepieceOp : public tensorflow::OpKernel {
7476 }
7577 tensorflow::Tensor* output_values_tensor = nullptr ;
7678 tensorflow::Tensor* output_splits_tensor = nullptr ;
77-
79+ OP_REQUIRES (ctx, encoded.size () < std::numeric_limits<int32_t >::max (),
80+ errors::InvalidArgument (
81+ " Encoded input must contain less than 2^31 characters." ));
82+ OP_REQUIRES (
83+ ctx, splits.size () + 1 < std::numeric_limits<int32_t >::max (),
84+ errors::InvalidArgument (" Splits tensor is limited to 2^31-1 values." ));
7885 OP_REQUIRES_OK (
79- ctx, ctx->allocate_output (0 , {( int16_t ) encoded.size ()},
86+ ctx, ctx->allocate_output (0 , {static_cast < int32_t >( encoded.size () )},
8087 &output_values_tensor));
81- OP_REQUIRES_OK (ctx, ctx->allocate_output (1 , {(int16_t )splits.size () + 1 },
82- &output_splits_tensor));
88+ OP_REQUIRES_OK (
89+ ctx, ctx->allocate_output (1 , {static_cast <int32_t >(splits.size ()) + 1 },
90+ &output_splits_tensor));
8391
8492 auto values_tensor_flat = output_values_tensor->vec <int32>();
8593 auto splits_tensor_flat = output_splits_tensor->vec <int32>();
86- for (int i = 0 ; i < encoded.size (); ++i) {
94+ for (int32_t i = 0 ; i < encoded.size (); ++i) {
8795 values_tensor_flat (i) = encoded[i];
8896 }
8997 splits_tensor_flat (0 ) = 0 ;
90- for (int i = 0 ; i < splits.size (); ++i) {
98+ for (int32_t i = 0 ; i < splits.size (); ++i) {
9199 splits_tensor_flat (i + 1 ) = splits[i];
92100 }
93101 }
0 commit comments