|
17 | 17 | import tensorflow.compat.v2 as tf
|
18 | 18 |
|
19 | 19 | import itertools
|
| 20 | +import math |
20 | 21 | import os
|
21 | 22 | import random
|
22 | 23 | import string
|
@@ -955,33 +956,67 @@ def test_one_hot_output_shape(self):
|
955 | 956 | outputs = layer(inputs)
|
956 | 957 | self.assertAllEqual(outputs.shape.as_list(), [16, 2])
|
957 | 958 |
|
958 |
| - def test_multi_hot_output_hard_maximum(self): |
959 |
| - """Check binary output when pad_to_max_tokens=True.""" |
960 |
| - vocab_data = ["earth", "wind", "and", "fire"] |
961 |
| - input_array = np.array([["earth", "wind", "and", "fire", ""], |
962 |
| - ["fire", "fire", "and", "earth", "michigan"]]) |
963 |
| - expected_output = [ |
964 |
| - [0, 1, 1, 1, 1, 0], |
965 |
| - [1, 1, 0, 1, 1, 0], |
966 |
| - ] |
| 959 | + @parameterized.product( |
| 960 | + sparse=[True, False], |
| 961 | + adapt=[True, False], |
| 962 | + pad_to_max=[True, False], |
| 963 | + mode=["multi_hot", "count", "tf_idf"], |
| 964 | + ) |
| 965 | + def test_binned_output(self, sparse, adapt, pad_to_max, mode): |
| 966 | + """Check "multi_hot", "count", and "tf_idf" output.""" |
| 967 | + # Adapt breaks ties with sort order. |
| 968 | + vocab_data = ["wind", "fire", "earth", "and"] |
| 969 | + # IDF weight for a term in 1 out of 1 documents is log(1 + 1 / (1 + 1)) = log(1.5). |
| 970 | + idf_data = [math.log(1.5)] * 4 |
| 971 | + input_data = np.array([["and", "earth", "fire", "and", ""], |
| 972 | + ["michigan", "wind", "and", "ohio", ""]]) |
| 973 | + |
| 974 | + if mode == "count": |
| 975 | + expected_output = np.array([ |
| 976 | + [0, 0, 1, 1, 2], |
| 977 | + [2, 1, 0, 0, 1], |
| 978 | + ]) |
| 979 | + elif mode == "tf_idf": |
| 980 | + expected_output = np.array([ |
| 981 | + [0, 0, 1, 1, 2], |
| 982 | + [2, 1, 0, 0, 1], |
| 983 | + ]) * math.log(1.5) |
| 984 | + else: |
| 985 | + expected_output = np.array([ |
| 986 | + [0, 0, 1, 1, 1], |
| 987 | + [1, 1, 0, 0, 1], |
| 988 | + ]) |
| 989 | + expected_output_shape = [None, 5] |
| 990 | + if pad_to_max: |
| 991 | + expected_output = np.concatenate((expected_output, [[0], [0]]), axis=1) |
| 992 | + expected_output_shape = [None, 6] |
967 | 993 |
|
968 |
| - input_data = keras.Input(shape=(None,), dtype=tf.string) |
| 994 | + inputs = keras.Input(shape=(None,), dtype=tf.string) |
969 | 995 | layer = index_lookup.IndexLookup(
|
970 | 996 | max_tokens=6,
|
971 | 997 | num_oov_indices=1,
|
972 | 998 | mask_token="",
|
973 | 999 | oov_token="[OOV]",
|
974 |
| - output_mode=index_lookup.MULTI_HOT, |
975 |
| - pad_to_max_tokens=True, |
| 1000 | + output_mode=mode, |
| 1001 | + pad_to_max_tokens=pad_to_max, |
| 1002 | + sparse=sparse, |
| 1003 | + vocabulary=None if adapt else vocab_data, |
| 1004 | + idf_weights=None if adapt or mode != "tf_idf" else idf_data, |
976 | 1005 | dtype=tf.string)
|
977 |
| - layer.set_vocabulary(vocab_data) |
978 |
| - binary_data = layer(input_data) |
979 |
| - model = keras.Model(inputs=input_data, outputs=binary_data) |
980 |
| - output_dataset = model.predict(input_array) |
981 |
| - self.assertAllEqual(expected_output, output_dataset) |
| 1006 | + if adapt: |
| 1007 | + layer.adapt(vocab_data) |
| 1008 | + outputs = layer(inputs) |
| 1009 | + model = keras.Model(inputs, outputs) |
| 1010 | + output_data = model.predict(input_data) |
| 1011 | + if sparse: |
| 1012 | + output_data = tf.sparse.to_dense(output_data) |
| 1013 | + # Check output data. |
| 1014 | + self.assertAllClose(expected_output, output_data) |
| 1015 | + # Check symbolic output shape. |
| 1016 | + self.assertAllEqual(expected_output_shape, outputs.shape.as_list()) |
982 | 1017 |
|
983 | 1018 | def test_multi_hot_output_no_oov(self):
|
984 |
| - """Check binary output when pad_to_max_tokens=True.""" |
| 1019 | + """Check multi hot output when num_oov_indices=0.""" |
985 | 1020 | vocab_data = ["earth", "wind", "and", "fire"]
|
986 | 1021 | valid_input = np.array([["earth", "wind", "and", "fire"],
|
987 | 1022 | ["fire", "and", "earth", ""]])
|
@@ -1050,188 +1085,6 @@ def test_multi_hot_output_hard_maximum_multiple_adapts(self):
|
1050 | 1085 | self.assertAllEqual(first_expected_output, first_output)
|
1051 | 1086 | self.assertAllEqual(second_expected_output, second_output)
|
1052 | 1087 |
|
1053 |
| - def test_multi_hot_output_soft_maximum(self): |
1054 |
| - """Check multi_hot output when pad_to_max_tokens=False.""" |
1055 |
| - vocab_data = ["earth", "wind", "and", "fire"] |
1056 |
| - input_array = np.array([["earth", "wind", "and", "fire", ""], |
1057 |
| - ["fire", "and", "earth", "michigan", ""]]) |
1058 |
| - expected_output = [ |
1059 |
| - [0, 1, 1, 1, 1], |
1060 |
| - [1, 1, 0, 1, 1], |
1061 |
| - ] |
1062 |
| - |
1063 |
| - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1064 |
| - layer = index_lookup.IndexLookup( |
1065 |
| - max_tokens=None, |
1066 |
| - num_oov_indices=1, |
1067 |
| - mask_token="", |
1068 |
| - oov_token="[OOV]", |
1069 |
| - output_mode=index_lookup.MULTI_HOT, |
1070 |
| - dtype=tf.string) |
1071 |
| - layer.set_vocabulary(vocab_data) |
1072 |
| - binary_data = layer(input_data) |
1073 |
| - model = keras.Model(inputs=input_data, outputs=binary_data) |
1074 |
| - output_dataset = model.predict(input_array) |
1075 |
| - self.assertAllEqual(expected_output, output_dataset) |
1076 |
| - |
1077 |
| - def test_multi_hot_output_shape(self): |
1078 |
| - input_data = keras.Input(batch_size=16, shape=(4,), dtype=tf.string) |
1079 |
| - layer = index_lookup.IndexLookup( |
1080 |
| - max_tokens=2, |
1081 |
| - num_oov_indices=1, |
1082 |
| - mask_token="", |
1083 |
| - oov_token="[OOV]", |
1084 |
| - output_mode=index_lookup.MULTI_HOT, |
1085 |
| - vocabulary=["foo"], |
1086 |
| - dtype=tf.string) |
1087 |
| - binary_data = layer(input_data) |
1088 |
| - self.assertAllEqual(binary_data.shape.as_list(), [16, 2]) |
1089 |
| - |
1090 |
| - def test_count_output_hard_maxiumum(self): |
1091 |
| - """Check count output when pad_to_max_tokens=True.""" |
1092 |
| - vocab_data = ["earth", "wind", "and", "fire"] |
1093 |
| - input_array = np.array([["earth", "wind", "and", "wind", ""], |
1094 |
| - ["fire", "fire", "fire", "michigan", ""]]) |
1095 |
| - expected_output = [ |
1096 |
| - [0, 1, 2, 1, 0, 0], |
1097 |
| - [1, 0, 0, 0, 3, 0], |
1098 |
| - ] |
1099 |
| - |
1100 |
| - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1101 |
| - layer = index_lookup.IndexLookup( |
1102 |
| - max_tokens=6, |
1103 |
| - num_oov_indices=1, |
1104 |
| - mask_token="", |
1105 |
| - oov_token="[OOV]", |
1106 |
| - output_mode=index_lookup.COUNT, |
1107 |
| - pad_to_max_tokens=True, |
1108 |
| - dtype=tf.string) |
1109 |
| - layer.set_vocabulary(vocab_data) |
1110 |
| - count_data = layer(input_data) |
1111 |
| - model = keras.Model(inputs=input_data, outputs=count_data) |
1112 |
| - output_dataset = model.predict(input_array) |
1113 |
| - self.assertAllEqual(expected_output, output_dataset) |
1114 |
| - |
1115 |
| - def test_count_output_soft_maximum(self): |
1116 |
| - """Check count output when pad_to_max_tokens=False.""" |
1117 |
| - vocab_data = ["earth", "wind", "and", "fire"] |
1118 |
| - input_array = np.array([["earth", "wind", "and", "wind", ""], |
1119 |
| - ["fire", "fire", "fire", "michigan", ""]]) |
1120 |
| - expected_output = [ |
1121 |
| - [0, 1, 2, 1, 0], |
1122 |
| - [1, 0, 0, 0, 3], |
1123 |
| - ] |
1124 |
| - |
1125 |
| - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1126 |
| - layer = index_lookup.IndexLookup( |
1127 |
| - max_tokens=None, |
1128 |
| - num_oov_indices=1, |
1129 |
| - mask_token="", |
1130 |
| - oov_token="[OOV]", |
1131 |
| - output_mode=index_lookup.COUNT, |
1132 |
| - dtype=tf.string) |
1133 |
| - layer.set_vocabulary(vocab_data) |
1134 |
| - count_data = layer(input_data) |
1135 |
| - model = keras.Model(inputs=input_data, outputs=count_data) |
1136 |
| - output_dataset = model.predict(input_array) |
1137 |
| - self.assertAllEqual(expected_output, output_dataset) |
1138 |
| - |
1139 |
| - def test_count_output_shape(self): |
1140 |
| - input_data = keras.Input(batch_size=16, shape=(4,), dtype=tf.string) |
1141 |
| - layer = index_lookup.IndexLookup( |
1142 |
| - max_tokens=2, |
1143 |
| - num_oov_indices=1, |
1144 |
| - mask_token="", |
1145 |
| - oov_token="[OOV]", |
1146 |
| - output_mode=index_lookup.COUNT, |
1147 |
| - vocabulary=["foo"], |
1148 |
| - dtype=tf.string) |
1149 |
| - count_data = layer(input_data) |
1150 |
| - self.assertAllEqual(count_data.shape.as_list(), [16, 2]) |
1151 |
| - |
1152 |
| - @parameterized.named_parameters( |
1153 |
| - ("sparse", True), |
1154 |
| - ("dense", False), |
1155 |
| - ) |
1156 |
| - def test_ifidf_output_hard_maximum(self, sparse): |
1157 |
| - """Check tf-idf output when pad_to_max_tokens=True.""" |
1158 |
| - vocab_data = ["earth", "wind", "and", "fire"] |
1159 |
| - # OOV idf weight (bucket 0) should 0.5, the average of passed weights. |
1160 |
| - idf_weights = [.4, .25, .75, .6] |
1161 |
| - input_array = np.array([["earth", "wind", "and", "earth", ""], |
1162 |
| - ["ohio", "fire", "earth", "michigan", ""]]) |
1163 |
| - expected_output = [ |
1164 |
| - [0.00, 0.80, 0.25, 0.75, 0.00, 0.00], |
1165 |
| - [1.00, 0.40, 0.00, 0.00, 0.60, 0.00], |
1166 |
| - ] |
1167 |
| - |
1168 |
| - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1169 |
| - layer = index_lookup.IndexLookup( |
1170 |
| - max_tokens=6, |
1171 |
| - num_oov_indices=1, |
1172 |
| - mask_token="", |
1173 |
| - oov_token="[OOV]", |
1174 |
| - output_mode=index_lookup.TF_IDF, |
1175 |
| - pad_to_max_tokens=True, |
1176 |
| - dtype=tf.string, |
1177 |
| - sparse=sparse, |
1178 |
| - vocabulary=vocab_data, |
1179 |
| - idf_weights=idf_weights) |
1180 |
| - layer_output = layer(input_data) |
1181 |
| - model = keras.Model(inputs=input_data, outputs=layer_output) |
1182 |
| - output_dataset = model.predict(input_array) |
1183 |
| - if sparse: |
1184 |
| - output_dataset = tf.sparse.to_dense(output_dataset) |
1185 |
| - self.assertAllClose(expected_output, output_dataset) |
1186 |
| - |
1187 |
| - @parameterized.named_parameters( |
1188 |
| - ("sparse", True), |
1189 |
| - ("dense", False), |
1190 |
| - ) |
1191 |
| - def test_ifidf_output_soft_maximum(self, sparse): |
1192 |
| - """Check tf-idf output when pad_to_max_tokens=False.""" |
1193 |
| - vocab_data = ["earth", "wind", "and", "fire"] |
1194 |
| - # OOV idf weight (bucket 0) should 0.5, the average of passed weights. |
1195 |
| - idf_weights = [.4, .25, .75, .6] |
1196 |
| - input_array = np.array([["earth", "wind", "and", "earth", ""], |
1197 |
| - ["ohio", "fire", "earth", "michigan", ""]]) |
1198 |
| - expected_output = [ |
1199 |
| - [0.00, 0.80, 0.25, 0.75, 0.00], |
1200 |
| - [1.00, 0.40, 0.00, 0.00, 0.60], |
1201 |
| - ] |
1202 |
| - |
1203 |
| - input_data = keras.Input(shape=(None,), dtype=tf.string) |
1204 |
| - layer = index_lookup.IndexLookup( |
1205 |
| - max_tokens=None, |
1206 |
| - num_oov_indices=1, |
1207 |
| - mask_token="", |
1208 |
| - oov_token="[OOV]", |
1209 |
| - output_mode=index_lookup.TF_IDF, |
1210 |
| - dtype=tf.string, |
1211 |
| - sparse=sparse, |
1212 |
| - vocabulary=vocab_data, |
1213 |
| - idf_weights=idf_weights) |
1214 |
| - layer_output = layer(input_data) |
1215 |
| - model = keras.Model(inputs=input_data, outputs=layer_output) |
1216 |
| - output_dataset = model.predict(input_array) |
1217 |
| - if sparse: |
1218 |
| - output_dataset = tf.sparse.to_dense(output_dataset) |
1219 |
| - self.assertAllClose(expected_output, output_dataset) |
1220 |
| - |
1221 |
| - def test_ifidf_output_shape(self): |
1222 |
| - input_data = keras.Input(batch_size=16, shape=(4,), dtype=tf.string) |
1223 |
| - layer = index_lookup.IndexLookup( |
1224 |
| - max_tokens=2, |
1225 |
| - num_oov_indices=1, |
1226 |
| - mask_token="", |
1227 |
| - oov_token="[OOV]", |
1228 |
| - output_mode=index_lookup.TF_IDF, |
1229 |
| - dtype=tf.string, |
1230 |
| - vocabulary=["foo"], |
1231 |
| - idf_weights=[1.0]) |
1232 |
| - layer_output = layer(input_data) |
1233 |
| - self.assertAllEqual(layer_output.shape.as_list(), [16, 2]) |
1234 |
| - |
1235 | 1088 | def test_int_output_file_vocab(self):
|
1236 | 1089 | vocab_data = ["earth", "wind", "and", "fire"]
|
1237 | 1090 | input_array = np.array([["earth", "wind", "and", "fire"],
|
|
0 commit comments