
Commit 9b4e61b

Add more scoring tests and rename error counting flag
1 parent fb0f745 commit 9b4e61b

File tree

seqscore/conll.py
seqscore/scoring.py
tests/test_scoring.py

3 files changed: +87 additions, -28 deletions

seqscore/conll.py

Lines changed: 1 addition & 1 deletion

@@ -514,7 +514,7 @@ def score_conll_files(
         )
 
         class_scores, acc_scores = compute_scores(
-            pred_docs, ref_docs, count_fp_fn=error_counts
+            pred_docs, ref_docs, count_fp_fn_examples=error_counts
        )
        all_class_scores.append(class_scores)
        all_acc_scores.append(class_scores)
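Because the parameter is keyword-only (note the bare * in the compute_scores signature in the scoring.py hunk below), the rename is visible to every caller that enables error counting, so any other call site still using the old keyword needs the same one-line edit. A minimal before/after sketch of such a call site, reusing the variable names from the hunk above:

    # Sketch of the caller-side change for the rename (names taken from the diff above).
    # Before this commit:
    #     compute_scores(pred_docs, ref_docs, count_fp_fn=error_counts)
    # After this commit:
    class_scores, acc_scores = compute_scores(
        pred_docs, ref_docs, count_fp_fn_examples=error_counts
    )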

seqscore/scoring.py

Lines changed: 11 additions & 6 deletions

@@ -133,7 +133,7 @@ def compute_scores(
     pred_docs: Sequence[Sequence[LabeledSequence]],
     ref_docs: Sequence[Sequence[LabeledSequence]],
     *,
-    count_fp_fn: bool = False,
+    count_fp_fn_examples: bool = False,
 ) -> tuple[ClassificationScore, AccuracyScore]:
     accuracy = AccuracyScore()
     classification = ClassificationScore()
@@ -174,7 +174,7 @@ def compute_scores(
                 ref_sequence.mentions,
                 classification,
                 tokens=ref_sequence.tokens,
-                count_fp_fn=count_fp_fn,
+                count_fp_fn_examples=count_fp_fn_examples,
             )
 
     return classification, accuracy
@@ -205,13 +205,18 @@ def score_sequence_mentions(
     score: ClassificationScore,
     *,
     tokens: Optional[Sequence[str]] = (),
-    count_fp_fn: bool = False,
+    count_fp_fn_examples: bool = False,
 ) -> None:
     """Update a ClassificationScore for a single sequence's mentions.
 
     Since mentions are defined per-sequence, the behavior is not defined
-    if you provide mentions corresponding to multiple sequences.
+    if you provide mentions corresponding to multiple sequences. Tokens
+    must be provided if you want false positive and false negative examples
+    to be counted.
     """
+    if count_fp_fn_examples and not tokens:
+        raise ValueError("Tokens must be provided to count false positive/negative examples")
+
     # Compute span accuracy
     pred_mentions_set = set(pred_mentions)
     ref_mentions_set = set(ref_mentions)
@@ -226,7 +231,7 @@ def score_sequence_mentions(
             # False positive
             score.false_pos += 1
             score.type_scores[pred.type].false_pos += 1
-            if count_fp_fn:
+            if count_fp_fn_examples:
                 error_tokens = tokens[pred.span.start : pred.span.end]
                 score.count_false_positive(error_tokens, pred.type)
 
@@ -235,7 +240,7 @@ def score_sequence_mentions(
         if ref not in pred_mentions_set:
             score.false_neg += 1
             score.type_scores[ref.type].false_neg += 1
-            if count_fp_fn:
+            if count_fp_fn_examples:
                 error_tokens = tokens[ref.span.start : ref.span.end]
                 score.count_false_negative(error_tokens, ref.type)
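Taken together, the renamed flag and the new guard change how score_sequence_mentions must be called when example counting is enabled: tokens are now required. A minimal sketch of the new behavior, assuming Mention, Span, ClassificationScore, and score_sequence_mentions are importable from the package as in the tests below (exact import paths are not shown in this commit):

    # Hedged sketch: exercising the new count_fp_fn_examples guard.
    score = ClassificationScore()
    ref_mentions = [Mention(Span(0, 2), "PER")]
    pred_mentions = [Mention(Span(0, 1), "PER")]

    # Requesting example counting without tokens now raises ValueError.
    try:
        score_sequence_mentions(
            pred_mentions, ref_mentions, score, count_fp_fn_examples=True
        )
    except ValueError:
        pass  # Expected: tokens were not provided.

    # With tokens, the mismatched spans are tallied via count_false_positive
    # and count_false_negative as shown in the hunks above.
    score_sequence_mentions(
        pred_mentions,
        ref_mentions,
        score,
        tokens=["John", "Smith", "visited"],
        count_fp_fn_examples=True,
    )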

tests/test_scoring.py

Lines changed: 75 additions & 21 deletions

@@ -1,3 +1,4 @@
+from collections import Counter
 from decimal import Decimal
 
 import pytest
@@ -8,6 +9,7 @@
     AccuracyScore,
     ClassificationScore,
     TokenCountError,
+    TokensWithType,
     compute_scores,
     convert_score,
     score_label_sequences,
@@ -45,7 +47,7 @@ def test_score_sentence_labels_invalid() -> None:
         score_sequence_label_accuracy(pred_labels, ref_labels, AccuracyScore())
 
 
-def test_score_sentence_mentions_correct() -> None:
+def test_score_sequence_mentions_correct() -> None:
     ref_mentions = [Mention(Span(0, 2), "PER"), Mention(Span(4, 5), "ORG")]
     pred_mentions = [Mention(Span(0, 2), "PER"), Mention(Span(4, 5), "ORG")]
     score = ClassificationScore()
@@ -63,8 +65,14 @@ def test_score_sentence_mentions_correct() -> None:
     assert score.recall == 1.0
     assert score.f1 == 1.0
 
+    # Test that tokens are required for counting FP/FN
+    with pytest.raises(ValueError):
+        score_sequence_mentions(
+            pred_mentions, ref_mentions, score, count_fp_fn_examples=True
+        )
+
 
-def test_score_sentence_mentions_incorrect1() -> None:
+def test_score_sequence_mentions_incorrect1() -> None:
     ref_mentions = [
         Mention(Span(0, 2), "LOC"),
         Mention(Span(4, 5), "PER"),
@@ -100,6 +108,28 @@ def test_score_sentence_mentions_incorrect1() -> None:
         2 * (score.precision * score.recall) / (score.precision + score.recall)
     )
 
+    # Run again and check counted fp/fn examples. We do this in a second pass so
+    # we can cover both True/False cases for count_fp_fn_examples.
+    score2 = ClassificationScore()
+    tokens = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l"]
+    score_sequence_mentions(
+        pred_mentions, ref_mentions, score2, count_fp_fn_examples=True, tokens=tokens
+    )
+    expected_false_pos = Counter(
+        [
+            TokensWithType(("a", "b"), "ORG"),
+            TokensWithType(("g",), "SPURIOUS"),
+        ]
+    )
+    expected_false_neg = Counter(
+        [
+            TokensWithType(("a", "b"), "LOC"),
+            TokensWithType(("h",), "MISC"),
+        ]
+    )
+    assert score2.false_pos_examples == expected_false_pos
+    assert score2.false_neg_examples == expected_false_neg
+
 
 def test_score_label_sequences_correct() -> None:
     ref_labels = [["O", "B-ORG", "I-ORG", "O"], ["B-PER", "I-PER"]]
@@ -192,60 +222,84 @@ def test_accuracy_score_empty() -> None:
     assert score.accuracy == 0.0
 
 
+def test_compute_scores() -> None:
+    ref_labels = ("O", "B-ORG", "I-ORG", "O", "B-LOC")
+    ref_mentions = (
+        Mention(Span(1, 3), "ORG"),
+        Mention(Span(4, 5), "LOC"),
+    )
+    pred_labels = ("O", "B-ORG", "I-ORG", "O", "B-ORG")
+    pred_mentions = (
+        Mention(Span(1, 3), "ORG"),
+        Mention(Span(4, 5), "ORG"),
+    )
+    tokens = ("a", "b", "c", "d", "e")
+    ref_sequence = LabeledSequence(tokens, ref_labels, ref_mentions)
+    pred_sequence = LabeledSequence(tokens, pred_labels, pred_mentions)
+    class_score, acc_score = compute_scores([[pred_sequence]], [[ref_sequence]])
+    assert acc_score.accuracy == 4 / 5
+    print(class_score)
+    assert class_score.true_pos == 1
+    assert class_score.false_pos == 1
+    assert class_score.false_neg == 1
+
+
 def test_token_count_error() -> None:
-    ref_labels = ["O", "B-ORG", "I-ORG", "O"]
-    pred_labels = ["O", "B-ORG", "I-ORG", "O", "O"]
+    ref_labels = ("O", "B-ORG", "I-ORG", "O")
+    pred_labels = ("O", "B-ORG", "I-ORG", "O", "O")
     ref_sequence = LabeledSequence(
-        ["a", "b", "c", "d"], ref_labels, provenance=SequenceProvenance(0, "test")
+        ("a", "b", "c", "d"), ref_labels, provenance=SequenceProvenance(0, "test")
     )
     pred_sequence = LabeledSequence(
-        ["a", "b", "c", "d", "e"], pred_labels, provenance=SequenceProvenance(0, "test")
+        ("a", "b", "c", "d", "e"), pred_labels, provenance=SequenceProvenance(0, "test")
     )
     with pytest.raises(TokenCountError):
         compute_scores([[pred_sequence]], [[ref_sequence]])
 
 
-def test_provenance_none_raises_error() -> None:
-    labels = ["O", "B-ORG"]
-    sequence = LabeledSequence(["a", "b"], labels, provenance=None)
+def test_token_count_error_provenance_none_raises_error() -> None:
+    labels = ("O", "B-ORG")
+    sequence = LabeledSequence(("a", "b"), labels, provenance=None)
     with pytest.raises(ValueError):
         TokenCountError.from_predicted_sequence(2, sequence)
 
 
 def test_differing_num_docs() -> None:
-    ref_labels = ["O", "B-ORG"]
-    pred_labels = ["O", "B-LOC"]
+    ref_labels = ("O", "B-ORG")
+    pred_labels = ("O", "B-LOC")
+    tokens = ("a", "b")
     ref_sequence = LabeledSequence(
-        ["a", "b"], ref_labels, provenance=SequenceProvenance(0, "test")
+        tokens, ref_labels, provenance=SequenceProvenance(0, "test")
     )
     pred_sequence = LabeledSequence(
-        ["a", "b"], pred_labels, provenance=SequenceProvenance(0, "test")
+        tokens, pred_labels, provenance=SequenceProvenance(0, "test")
     )
     with pytest.raises(ValueError):
         compute_scores([[pred_sequence]], [[ref_sequence], [ref_sequence]])
 
 
 def test_differing_doc_length() -> None:
-    ref_labels = ["O", "B-ORG"]
-    pred_labels = ["O", "B-LOC"]
+    ref_labels = ("O", "B-ORG")
+    pred_labels = ("O", "B-LOC")
+    tokens = ("a", "b")
     ref_sequence = LabeledSequence(
-        ["a", "b"], ref_labels, provenance=SequenceProvenance(0, "test")
+        tokens, ref_labels, provenance=SequenceProvenance(0, "test")
     )
     pred_sequence = LabeledSequence(
-        ["a", "b"], pred_labels, provenance=SequenceProvenance(0, "test")
+        tokens, pred_labels, provenance=SequenceProvenance(0, "test")
     )
     with pytest.raises(ValueError):
         compute_scores([[pred_sequence]], [[ref_sequence, ref_sequence]])
 
 
 def test_differing_pred_and_ref_tokens() -> None:
-    ref_labels = ["O", "B-ORG"]
-    pred_labels = ["O", "B-LOC"]
+    ref_labels = ("O", "B-ORG")
+    pred_labels = ("O", "B-LOC")
     ref_sequence = LabeledSequence(
-        ["a", "b"], ref_labels, provenance=SequenceProvenance(0, "test")
+        ("a", "b"), ref_labels, provenance=SequenceProvenance(0, "test")
     )
     pred_sequence = LabeledSequence(
-        ["a", "c"], pred_labels, provenance=SequenceProvenance(0, "test")
+        ("a", "c"), pred_labels, provenance=SequenceProvenance(0, "test")
     )
     with pytest.raises(ValueError):
         compute_scores([[pred_sequence]], [[ref_sequence]])
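As the expanded test shows, the counted examples end up on the ClassificationScore as Counter objects keyed by TokensWithType; the attribute names false_pos_examples and false_neg_examples are taken from the assertions above. A small sketch of reading them back out after scoring, assuming a populated score object:

    # Hedged sketch: summarizing the most frequent error examples after scoring.
    for example, count in score.false_pos_examples.most_common(10):
        print("FP", count, example)
    for example, count in score.false_neg_examples.most_common(10):
        print("FN", count, example)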
