@@ -43,7 +43,7 @@ def test_trim_multiple_inputs_round_robin(self):
4343 seq1 = tf .constant (["a" , "b" , "c" ])
4444 seq2 = tf .constant (["x" , "y" , "z" ])
4545 packer = MultiSegmentPacker (
46- 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncator = "round_robin"
46+ 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncate = "round_robin"
4747 )
4848 output = packer ([seq1 , seq2 ])
4949 self .assertAllEqual (
@@ -58,7 +58,7 @@ def test_trim_multiple_inputs_waterfall(self):
5858 seq1 = tf .constant (["a" , "b" , "c" ])
5959 seq2 = tf .constant (["x" , "y" , "z" ])
6060 packer = MultiSegmentPacker (
61- 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncator = "waterfall"
61+ 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncate = "waterfall"
6262 )
6363 output = packer ([seq1 , seq2 ])
6464 self .assertAllEqual (
@@ -73,7 +73,7 @@ def test_trim_batched_inputs_round_robin(self):
7373 seq1 = tf .constant ([["a" , "b" , "c" ], ["a" , "b" , "c" ]])
7474 seq2 = tf .constant ([["x" , "y" , "z" ], ["x" , "y" , "z" ]])
7575 packer = MultiSegmentPacker (
76- 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncator = "round_robin"
76+ 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncate = "round_robin"
7777 )
7878 output = packer ([seq1 , seq2 ])
7979 self .assertAllEqual (
@@ -94,7 +94,7 @@ def test_trim_batched_inputs_waterfall(self):
9494 seq1 = tf .ragged .constant ([["a" , "b" , "c" ], ["a" , "b" ]])
9595 seq2 = tf .constant ([["x" , "y" , "z" ], ["x" , "y" , "z" ]])
9696 packer = MultiSegmentPacker (
97- 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncator = "waterfall"
97+ 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncate = "waterfall"
9898 )
9999 output = packer ([seq1 , seq2 ])
100100 self .assertAllEqual (
@@ -151,7 +151,7 @@ def test_config(self):
151151 seq1 = tf .ragged .constant ([["a" , "b" , "c" ], ["a" , "b" ]])
152152 seq2 = tf .ragged .constant ([["x" , "y" , "z" ], ["x" , "y" , "z" ]])
153153 original_packer = MultiSegmentPacker (
154- 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncator = "waterfall"
154+ 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncate = "waterfall"
155155 )
156156 cloned_packer = MultiSegmentPacker .from_config (
157157 original_packer .get_config ()
@@ -166,7 +166,7 @@ def test_saving(self, format):
166166 seq1 = tf .ragged .constant ([["a" , "b" , "c" ], ["a" , "b" ]])
167167 seq2 = tf .ragged .constant ([["x" , "y" , "z" ], ["x" , "y" , "z" ]])
168168 packer = MultiSegmentPacker (
169- 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncator = "waterfall"
169+ 7 , start_value = "[CLS]" , end_value = "[SEP]" , truncate = "waterfall"
170170 )
171171 inputs = (
172172 keras .Input (dtype = "string" , ragged = True , shape = (None ,)),
0 commit comments