Skip to content

Commit a8e0dbe

Browse files
authored
Add a BertPreprocessor class (#343)
- Add a BertTokenizer class
- Address review comments
1 parent 7bbaf9b commit a8e0dbe

File tree

6 files changed

+388
-107
lines changed

6 files changed

+388
-107
lines changed

keras_nlp/layers/multi_segment_packer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class MultiSegmentPacker(keras.layers.Layer):
3030
Takes as input a list or tuple of token segments. The layer will process
3131
inputs as follows:
3232
- Truncate all input segments to fit within `sequence_length` according to
33-
the `truncator` strategy.
33+
the `truncate` strategy.
3434
- Concatenate all input segments, adding a single `start_value` at the
3535
start of the entire sequence, and multiple `end_value`s at the end of
3636
each segment.
@@ -55,7 +55,7 @@ class MultiSegmentPacker(keras.layers.Layer):
5555
pad_value: The id or token that is to be placed into the unused
5656
positions after the last segment in the sequence
5757
(called "[PAD]" for BERT).
58-
truncator: The algorithm to truncate a list of batched segments to fit a
58+
truncate: The algorithm to truncate a list of batched segments to fit a
5959
per-example length limit. The value can be either `round_robin` or
6060
`waterfall`:
6161
- `"round_robin"`: Available space is assigned one token at a
@@ -104,17 +104,17 @@ def __init__(
104104
start_value,
105105
end_value,
106106
pad_value=None,
107-
truncator="round_robin",
107+
truncate="round_robin",
108108
**kwargs,
109109
):
110110
super().__init__(**kwargs)
111111
self.sequence_length = sequence_length
112-
if truncator not in ("round_robin", "waterfall"):
112+
if truncate not in ("round_robin", "waterfall"):
113113
raise ValueError(
114114
"Only 'round_robin' and 'waterfall' algorithms are "
115-
"supported. Received %s" % truncator
115+
"supported. Received %s" % truncate
116116
)
117-
self.truncator = truncator
117+
self.truncate = truncate
118118
self.start_value = start_value
119119
self.end_value = end_value
120120
self.pad_value = pad_value
@@ -127,7 +127,7 @@ def get_config(self):
127127
"start_value": self.start_value,
128128
"end_value": self.end_value,
129129
"pad_value": self.pad_value,
130-
"truncator": self.truncator,
130+
"truncate": self.truncate,
131131
}
132132
)
133133
return config
@@ -162,16 +162,16 @@ def _convert_dense(self, x):
162162
def _trim_inputs(self, inputs):
163163
"""Trim inputs to desired length."""
164164
num_special_tokens = len(inputs) + 1
165-
if self.truncator == "round_robin":
165+
if self.truncate == "round_robin":
166166
return tf_text.RoundRobinTrimmer(
167167
self.sequence_length - num_special_tokens
168168
).trim(inputs)
169-
elif self.truncator == "waterfall":
169+
elif self.truncate == "waterfall":
170170
return tf_text.WaterfallTrimmer(
171171
self.sequence_length - num_special_tokens
172172
).trim(inputs)
173173
else:
174-
raise ValueError("Unsupported truncator: %s" % self.truncator)
174+
raise ValueError("Unsupported truncate: %s" % self.truncate)
175175

176176
def _combine_inputs(self, segments):
177177
"""Combine inputs with start and end values added."""

keras_nlp/layers/multi_segment_packer_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_trim_multiple_inputs_round_robin(self):
4343
seq1 = tf.constant(["a", "b", "c"])
4444
seq2 = tf.constant(["x", "y", "z"])
4545
packer = MultiSegmentPacker(
46-
7, start_value="[CLS]", end_value="[SEP]", truncator="round_robin"
46+
7, start_value="[CLS]", end_value="[SEP]", truncate="round_robin"
4747
)
4848
output = packer([seq1, seq2])
4949
self.assertAllEqual(
@@ -58,7 +58,7 @@ def test_trim_multiple_inputs_waterfall(self):
5858
seq1 = tf.constant(["a", "b", "c"])
5959
seq2 = tf.constant(["x", "y", "z"])
6060
packer = MultiSegmentPacker(
61-
7, start_value="[CLS]", end_value="[SEP]", truncator="waterfall"
61+
7, start_value="[CLS]", end_value="[SEP]", truncate="waterfall"
6262
)
6363
output = packer([seq1, seq2])
6464
self.assertAllEqual(
@@ -73,7 +73,7 @@ def test_trim_batched_inputs_round_robin(self):
7373
seq1 = tf.constant([["a", "b", "c"], ["a", "b", "c"]])
7474
seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
7575
packer = MultiSegmentPacker(
76-
7, start_value="[CLS]", end_value="[SEP]", truncator="round_robin"
76+
7, start_value="[CLS]", end_value="[SEP]", truncate="round_robin"
7777
)
7878
output = packer([seq1, seq2])
7979
self.assertAllEqual(
@@ -94,7 +94,7 @@ def test_trim_batched_inputs_waterfall(self):
9494
seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
9595
seq2 = tf.constant([["x", "y", "z"], ["x", "y", "z"]])
9696
packer = MultiSegmentPacker(
97-
7, start_value="[CLS]", end_value="[SEP]", truncator="waterfall"
97+
7, start_value="[CLS]", end_value="[SEP]", truncate="waterfall"
9898
)
9999
output = packer([seq1, seq2])
100100
self.assertAllEqual(
@@ -151,7 +151,7 @@ def test_config(self):
151151
seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
152152
seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
153153
original_packer = MultiSegmentPacker(
154-
7, start_value="[CLS]", end_value="[SEP]", truncator="waterfall"
154+
7, start_value="[CLS]", end_value="[SEP]", truncate="waterfall"
155155
)
156156
cloned_packer = MultiSegmentPacker.from_config(
157157
original_packer.get_config()
@@ -166,7 +166,7 @@ def test_saving(self, format):
166166
seq1 = tf.ragged.constant([["a", "b", "c"], ["a", "b"]])
167167
seq2 = tf.ragged.constant([["x", "y", "z"], ["x", "y", "z"]])
168168
packer = MultiSegmentPacker(
169-
7, start_value="[CLS]", end_value="[SEP]", truncator="waterfall"
169+
7, start_value="[CLS]", end_value="[SEP]", truncate="waterfall"
170170
)
171171
inputs = (
172172
keras.Input(dtype="string", ragged=True, shape=(None,)),

keras_nlp/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from keras_nlp.models.bert import BertCustom
1818
from keras_nlp.models.bert import BertLarge
1919
from keras_nlp.models.bert import BertMedium
20+
from keras_nlp.models.bert import BertPreprocessor
2021
from keras_nlp.models.bert import BertSmall
2122
from keras_nlp.models.bert import BertTiny
2223
from keras_nlp.models.roberta import RobertaBase

0 commit comments

Comments (0)