@@ -104,3 +104,52 @@ def test_tf_data_compatibility(self):
104
104
ds = tf_data .Dataset .from_tensor_slices (input_data ).batch (4 ).map (layer )
105
105
output = next (iter (ds )).numpy ()
106
106
self .assertAllClose (output , np .array ([2 , 3 , 4 , 0 ]))
107
+
108
+ def test_one_hot_output_with_higher_rank_input (self ):
109
+ input_data = np .array ([[1 , 2 ], [3 , 0 ]])
110
+ vocabulary = [1 , 2 , 3 ]
111
+ layer = layers .IntegerLookup (
112
+ vocabulary = vocabulary , output_mode = "one_hot"
113
+ )
114
+ output_data = layer (input_data )
115
+ self .assertEqual (output_data .shape , (2 , 2 , 4 ))
116
+ expected_output = np .array (
117
+ [
118
+ [[0 , 1 , 0 , 0 ], [0 , 0 , 1 , 0 ]],
119
+ [[0 , 0 , 0 , 1 ], [1 , 0 , 0 , 0 ]],
120
+ ]
121
+ )
122
+ self .assertAllClose (output_data , expected_output )
123
+ output_data_3d = layer (np .expand_dims (input_data , axis = 0 ))
124
+ self .assertEqual (output_data_3d .shape , (1 , 2 , 2 , 4 ))
125
+ self .assertAllClose (
126
+ output_data_3d , np .expand_dims (expected_output , axis = 0 )
127
+ )
128
+
129
+ def test_multi_hot_output_shape (self ):
130
+ input_data = np .array ([[1 , 2 ], [3 , 0 ]])
131
+ vocabulary = [1 , 2 , 3 ]
132
+ layer = layers .IntegerLookup (
133
+ vocabulary = vocabulary , output_mode = "multi_hot"
134
+ )
135
+ output_data = layer (input_data )
136
+ self .assertEqual (output_data .shape , (2 , 4 ))
137
+
138
+ def test_count_output_shape (self ):
139
+ input_data = np .array ([[1 , 2 ], [3 , 0 ]])
140
+ vocabulary = [1 , 2 , 3 ]
141
+ layer = layers .IntegerLookup (vocabulary = vocabulary , output_mode = "count" )
142
+ output_data = layer (input_data )
143
+ self .assertEqual (output_data .shape , (2 , 4 ))
144
+
145
+ def test_tf_idf_output_shape (self ):
146
+ input_data = np .array ([[1 , 2 ], [3 , 0 ]])
147
+ vocabulary = [1 , 2 , 3 ]
148
+ idf_weights = [1.0 , 1.0 , 1.0 ]
149
+ layer = layers .IntegerLookup (
150
+ vocabulary = vocabulary ,
151
+ idf_weights = idf_weights ,
152
+ output_mode = "tf_idf" ,
153
+ )
154
+ output_data = layer (input_data )
155
+ self .assertEqual (output_data .shape , (2 , 4 ))
0 commit comments