Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 01a427a

Browse files
authored
Update label mapping for deepsparse.transformers.eval_downstream (#323) (#324)
* Update label mapping for deepsparse.transformers.eval_downstream
* Fix MNLI as well
1 parent 500d132 commit 01a427a

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

src/deepsparse/transformers/eval_downstream.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -123,20 +123,22 @@ def mnli_eval(args):
123123
)
124124
print(f"Engine info: {text_classify.model}")
125125

126+
label_map = {"entailment": 0, "neutral": 1, "contradiction": 2}
127+
126128
for idx, sample in enumerate(tqdm(mnli_matched)):
127-
pred = text_classify(sample["premise"], sample["hypothesis"])
129+
pred = text_classify([[sample["premise"], sample["hypothesis"]]])
128130
mnli_metrics.add_batch(
129-
predictions=[int(pred[0]["label"].split("_")[-1])],
131+
predictions=[label_map.get(pred[0]["label"])],
130132
references=[sample["label"]],
131133
)
132134

133135
if args.max_samples and idx > args.max_samples:
134136
break
135137

136138
for idx, sample in enumerate(tqdm(mnli_mismatched)):
137-
pred = text_classify(sample["premise"], sample["hypothesis"])
139+
pred = text_classify([[sample["premise"], sample["hypothesis"]]])
138140
mnli_metrics.add_batch(
139-
predictions=[int(pred[0]["label"].split("_")[-1])],
141+
predictions=[label_map.get(pred[0]["label"])],
140142
references=[sample["label"]],
141143
)
142144

@@ -161,11 +163,13 @@ def qqp_eval(args):
161163
)
162164
print(f"Engine info: {text_classify.model}")
163165

166+
label_map = {"not_duplicate": 0, "duplicate": 1}
167+
164168
for idx, sample in enumerate(tqdm(qqp)):
165169
pred = text_classify([[sample["question1"], sample["question2"]]])
166170

167171
qqp_metrics.add_batch(
168-
predictions=[int(pred[0]["label"].split("_")[-1])],
172+
predictions=[label_map.get(pred[0]["label"])],
169173
references=[sample["label"]],
170174
)
171175

@@ -190,13 +194,15 @@ def sst2_eval(args):
190194
)
191195
print(f"Engine info: {text_classify.model}")
192196

197+
label_map = {"negative": 0, "positive": 1}
198+
193199
for idx, sample in enumerate(tqdm(sst2)):
194200
pred = text_classify(
195201
sample["sentence"],
196202
)
197203

198204
sst2_metrics.add_batch(
199-
predictions=[int(pred[0]["label"].split("_")[-1])],
205+
predictions=[label_map.get(pred[0]["label"])],
200206
references=[sample["label"]],
201207
)
202208

@@ -229,6 +235,7 @@ def parse_args():
229235
"--dataset",
230236
type=str,
231237
choices=list(SUPPORTED_DATASETS.keys()),
238+
required=True,
232239
)
233240

234241
parser.add_argument(

0 commit comments

Comments
 (0)