@@ -123,20 +123,22 @@ def mnli_eval(args):
123123 )
124124 print (f"Engine info: { text_classify .model } " )
125125
126+ label_map = {"entailment" : 0 , "neutral" : 1 , "contradiction" : 2 }
127+
126128 for idx , sample in enumerate (tqdm (mnli_matched )):
127- pred = text_classify (sample ["premise" ], sample ["hypothesis" ])
129+ pred = text_classify ([[ sample ["premise" ], sample ["hypothesis" ]] ])
128130 mnli_metrics .add_batch (
129- predictions = [int (pred [0 ]["label" ]. split ( "_" )[ - 1 ])],
131+ predictions = [label_map . get (pred [0 ]["label" ])],
130132 references = [sample ["label" ]],
131133 )
132134
133135 if args .max_samples and idx > args .max_samples :
134136 break
135137
136138 for idx , sample in enumerate (tqdm (mnli_mismatched )):
137- pred = text_classify (sample ["premise" ], sample ["hypothesis" ])
139+ pred = text_classify ([[ sample ["premise" ], sample ["hypothesis" ]] ])
138140 mnli_metrics .add_batch (
139- predictions = [int (pred [0 ]["label" ]. split ( "_" )[ - 1 ])],
141+ predictions = [label_map . get (pred [0 ]["label" ])],
140142 references = [sample ["label" ]],
141143 )
142144
@@ -161,11 +163,13 @@ def qqp_eval(args):
161163 )
162164 print (f"Engine info: { text_classify .model } " )
163165
166+ label_map = {"not_duplicate" : 0 , "duplicate" : 1 }
167+
164168 for idx , sample in enumerate (tqdm (qqp )):
165169 pred = text_classify ([[sample ["question1" ], sample ["question2" ]]])
166170
167171 qqp_metrics .add_batch (
168- predictions = [int (pred [0 ]["label" ]. split ( "_" )[ - 1 ])],
172+ predictions = [label_map . get (pred [0 ]["label" ])],
169173 references = [sample ["label" ]],
170174 )
171175
@@ -190,13 +194,15 @@ def sst2_eval(args):
190194 )
191195 print (f"Engine info: { text_classify .model } " )
192196
197+ label_map = {"negative" : 0 , "positive" : 1 }
198+
193199 for idx , sample in enumerate (tqdm (sst2 )):
194200 pred = text_classify (
195201 sample ["sentence" ],
196202 )
197203
198204 sst2_metrics .add_batch (
199- predictions = [int (pred [0 ]["label" ]. split ( "_" )[ - 1 ])],
205+ predictions = [label_map . get (pred [0 ]["label" ])],
200206 references = [sample ["label" ]],
201207 )
202208
@@ -229,6 +235,7 @@ def parse_args():
229235 "--dataset" ,
230236 type = str ,
231237 choices = list (SUPPORTED_DATASETS .keys ()),
238+ required = True ,
232239 )
233240
234241 parser .add_argument (
0 commit comments