Skip to content

Commit a009455

Browse files
fix: Fix IntegerLookup docstring output shape documentation and add test coverage (#21625)
1 parent 1fc411f commit a009455

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

keras/src/layers/preprocessing/integer_lookup.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,12 @@ class IntegerLookup(IndexLookup):
111111
appeared in the sample.
112112
- `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm is
113113
applied to find the value in each token slot.
114-
For `"int"` output, any shape of input and output is supported.
115-
For all other output modes, currently only output up to rank 2
116-
is supported. Defaults to `"int"`.
114+
For `"int"` output, the output shape matches the input shape.
115+
For `"one_hot"` output, the output shape is
116+
`input_shape + (vocabulary_size,)`, where `input_shape` may
117+
have arbitrary rank. For other output modes (`"multi_hot"`,
118+
`"count"`, `"tf_idf"`), the output shape is `(batch_size,
119+
vocabulary_size)`. Defaults to `"int"`.
117120
pad_to_max_tokens: Only applicable when `output_mode` is `"multi_hot"`,
118121
`"count"`, or `"tf_idf"`. If `True`, the output will have
119122
its feature axis padded to `max_tokens` even if the number

keras/src/layers/preprocessing/integer_lookup_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,52 @@ def test_tf_data_compatibility(self):
104104
ds = tf_data.Dataset.from_tensor_slices(input_data).batch(4).map(layer)
105105
output = next(iter(ds)).numpy()
106106
self.assertAllClose(output, np.array([2, 3, 4, 0]))
107+
108+
def test_one_hot_output_with_higher_rank_input(self):
109+
input_data = np.array([[1, 2], [3, 0]])
110+
vocabulary = [1, 2, 3]
111+
layer = layers.IntegerLookup(
112+
vocabulary=vocabulary, output_mode="one_hot"
113+
)
114+
output_data = layer(input_data)
115+
self.assertEqual(output_data.shape, (2, 2, 4))
116+
expected_output = np.array(
117+
[
118+
[[0, 1, 0, 0], [0, 0, 1, 0]],
119+
[[0, 0, 0, 1], [1, 0, 0, 0]],
120+
]
121+
)
122+
self.assertAllClose(output_data, expected_output)
123+
output_data_3d = layer(np.expand_dims(input_data, axis=0))
124+
self.assertEqual(output_data_3d.shape, (1, 2, 2, 4))
125+
self.assertAllClose(
126+
output_data_3d, np.expand_dims(expected_output, axis=0)
127+
)
128+
129+
def test_multi_hot_output_shape(self):
130+
input_data = np.array([[1, 2], [3, 0]])
131+
vocabulary = [1, 2, 3]
132+
layer = layers.IntegerLookup(
133+
vocabulary=vocabulary, output_mode="multi_hot"
134+
)
135+
output_data = layer(input_data)
136+
self.assertEqual(output_data.shape, (2, 4))
137+
138+
def test_count_output_shape(self):
139+
input_data = np.array([[1, 2], [3, 0]])
140+
vocabulary = [1, 2, 3]
141+
layer = layers.IntegerLookup(vocabulary=vocabulary, output_mode="count")
142+
output_data = layer(input_data)
143+
self.assertEqual(output_data.shape, (2, 4))
144+
145+
def test_tf_idf_output_shape(self):
146+
input_data = np.array([[1, 2], [3, 0]])
147+
vocabulary = [1, 2, 3]
148+
idf_weights = [1.0, 1.0, 1.0]
149+
layer = layers.IntegerLookup(
150+
vocabulary=vocabulary,
151+
idf_weights=idf_weights,
152+
output_mode="tf_idf",
153+
)
154+
output_data = layer(input_data)
155+
self.assertEqual(output_data.shape, (2, 4))

0 commit comments

Comments
 (0)