1+ from collections import Counter
12from decimal import Decimal
23
34import pytest
89 AccuracyScore ,
910 ClassificationScore ,
1011 TokenCountError ,
12+ TokensWithType ,
1113 compute_scores ,
1214 convert_score ,
1315 score_label_sequences ,
@@ -45,7 +47,7 @@ def test_score_sentence_labels_invalid() -> None:
4547 score_sequence_label_accuracy (pred_labels , ref_labels , AccuracyScore ())
4648
4749
48- def test_score_sentence_mentions_correct () -> None :
50+ def test_score_sequence_mentions_correct () -> None :
4951 ref_mentions = [Mention (Span (0 , 2 ), "PER" ), Mention (Span (4 , 5 ), "ORG" )]
5052 pred_mentions = [Mention (Span (0 , 2 ), "PER" ), Mention (Span (4 , 5 ), "ORG" )]
5153 score = ClassificationScore ()
@@ -63,8 +65,14 @@ def test_score_sentence_mentions_correct() -> None:
6365 assert score .recall == 1.0
6466 assert score .f1 == 1.0
6567
68+ # Test that tokens are required for counting FP/FN
69+ with pytest .raises (ValueError ):
70+ score_sequence_mentions (
71+ pred_mentions , ref_mentions , score , count_fp_fn_examples = True
72+ )
73+
6674
67- def test_score_sentence_mentions_incorrect1 () -> None :
75+ def test_score_sequence_mentions_incorrect1 () -> None :
6876 ref_mentions = [
6977 Mention (Span (0 , 2 ), "LOC" ),
7078 Mention (Span (4 , 5 ), "PER" ),
@@ -100,6 +108,28 @@ def test_score_sentence_mentions_incorrect1() -> None:
100108 2 * (score .precision * score .recall ) / (score .precision + score .recall )
101109 )
102110
111+ # Run again and check counted fp/fn examples. We do this in a second pass so
112+ # we can cover both True/False cases for count_fp_fn_examples.
113+ score2 = ClassificationScore ()
114+ tokens = ["a" , "b" , "c" , "d" , "e" , "f" , "g" , "h" , "i" , "j" , "k" , "l" ]
115+ score_sequence_mentions (
116+ pred_mentions , ref_mentions , score2 , count_fp_fn_examples = True , tokens = tokens
117+ )
118+ expected_false_pos = Counter (
119+ [
120+ TokensWithType (("a" , "b" ), "ORG" ),
121+ TokensWithType (("g" ,), "SPURIOUS" ),
122+ ]
123+ )
124+ expected_false_neg = Counter (
125+ [
126+ TokensWithType (("a" , "b" ), "LOC" ),
127+ TokensWithType (("h" ,), "MISC" ),
128+ ]
129+ )
130+ assert score2 .false_pos_examples == expected_false_pos
131+ assert score2 .false_neg_examples == expected_false_neg
132+
103133
104134def test_score_label_sequences_correct () -> None :
105135 ref_labels = [["O" , "B-ORG" , "I-ORG" , "O" ], ["B-PER" , "I-PER" ]]
@@ -192,60 +222,84 @@ def test_accuracy_score_empty() -> None:
192222 assert score .accuracy == 0.0
193223
194224
225+ def test_compute_scores () -> None :
226+ ref_labels = ("O" , "B-ORG" , "I-ORG" , "O" , "B-LOC" )
227+ ref_mentions = (
228+ Mention (Span (1 , 3 ), "ORG" ),
229+ Mention (Span (4 , 5 ), "LOC" ),
230+ )
231+ pred_labels = ("O" , "B-ORG" , "I-ORG" , "O" , "B-ORG" )
232+ pred_mentions = (
233+ Mention (Span (1 , 3 ), "ORG" ),
234+ Mention (Span (4 , 5 ), "ORG" ),
235+ )
236+ tokens = ("a" , "b" , "c" , "d" , "e" )
237+ ref_sequence = LabeledSequence (tokens , ref_labels , ref_mentions )
238+ pred_sequence = LabeledSequence (tokens , pred_labels , pred_mentions )
239+ class_score , acc_score = compute_scores ([[pred_sequence ]], [[ref_sequence ]])
240+ assert acc_score .accuracy == 4 / 5
241+ print (class_score )
242+ assert class_score .true_pos == 1
243+ assert class_score .false_pos == 1
244+ assert class_score .false_neg == 1
245+
246+
195247def test_token_count_error () -> None :
196- ref_labels = [ "O" , "B-ORG" , "I-ORG" , "O" ]
197- pred_labels = [ "O" , "B-ORG" , "I-ORG" , "O" , "O" ]
248+ ref_labels = ( "O" , "B-ORG" , "I-ORG" , "O" )
249+ pred_labels = ( "O" , "B-ORG" , "I-ORG" , "O" , "O" )
198250 ref_sequence = LabeledSequence (
199- [ "a" , "b" , "c" , "d" ] , ref_labels , provenance = SequenceProvenance (0 , "test" )
251+ ( "a" , "b" , "c" , "d" ) , ref_labels , provenance = SequenceProvenance (0 , "test" )
200252 )
201253 pred_sequence = LabeledSequence (
202- [ "a" , "b" , "c" , "d" , "e" ] , pred_labels , provenance = SequenceProvenance (0 , "test" )
254+ ( "a" , "b" , "c" , "d" , "e" ) , pred_labels , provenance = SequenceProvenance (0 , "test" )
203255 )
204256 with pytest .raises (TokenCountError ):
205257 compute_scores ([[pred_sequence ]], [[ref_sequence ]])
206258
207259
208- def test_provenance_none_raises_error () -> None :
209- labels = [ "O" , "B-ORG" ]
210- sequence = LabeledSequence ([ "a" , "b" ] , labels , provenance = None )
260+ def test_token_count_error_provenance_none_raises_error () -> None :
261+ labels = ( "O" , "B-ORG" )
262+ sequence = LabeledSequence (( "a" , "b" ) , labels , provenance = None )
211263 with pytest .raises (ValueError ):
212264 TokenCountError .from_predicted_sequence (2 , sequence )
213265
214266
215267def test_differing_num_docs () -> None :
216- ref_labels = ["O" , "B-ORG" ]
217- pred_labels = ["O" , "B-LOC" ]
268+ ref_labels = ("O" , "B-ORG" )
269+ pred_labels = ("O" , "B-LOC" )
270+ tokens = ("a" , "b" )
218271 ref_sequence = LabeledSequence (
219- [ "a" , "b" ] , ref_labels , provenance = SequenceProvenance (0 , "test" )
272+ tokens , ref_labels , provenance = SequenceProvenance (0 , "test" )
220273 )
221274 pred_sequence = LabeledSequence (
222- [ "a" , "b" ] , pred_labels , provenance = SequenceProvenance (0 , "test" )
275+ tokens , pred_labels , provenance = SequenceProvenance (0 , "test" )
223276 )
224277 with pytest .raises (ValueError ):
225278 compute_scores ([[pred_sequence ]], [[ref_sequence ], [ref_sequence ]])
226279
227280
228281def test_differing_doc_length () -> None :
229- ref_labels = ["O" , "B-ORG" ]
230- pred_labels = ["O" , "B-LOC" ]
282+ ref_labels = ("O" , "B-ORG" )
283+ pred_labels = ("O" , "B-LOC" )
284+ tokens = ("a" , "b" )
231285 ref_sequence = LabeledSequence (
232- [ "a" , "b" ] , ref_labels , provenance = SequenceProvenance (0 , "test" )
286+ tokens , ref_labels , provenance = SequenceProvenance (0 , "test" )
233287 )
234288 pred_sequence = LabeledSequence (
235- [ "a" , "b" ] , pred_labels , provenance = SequenceProvenance (0 , "test" )
289+ tokens , pred_labels , provenance = SequenceProvenance (0 , "test" )
236290 )
237291 with pytest .raises (ValueError ):
238292 compute_scores ([[pred_sequence ]], [[ref_sequence , ref_sequence ]])
239293
240294
241295def test_differing_pred_and_ref_tokens () -> None :
242- ref_labels = [ "O" , "B-ORG" ]
243- pred_labels = [ "O" , "B-LOC" ]
296+ ref_labels = ( "O" , "B-ORG" )
297+ pred_labels = ( "O" , "B-LOC" )
244298 ref_sequence = LabeledSequence (
245- [ "a" , "b" ] , ref_labels , provenance = SequenceProvenance (0 , "test" )
299+ ( "a" , "b" ) , ref_labels , provenance = SequenceProvenance (0 , "test" )
246300 )
247301 pred_sequence = LabeledSequence (
248- [ "a" , "c" ] , pred_labels , provenance = SequenceProvenance (0 , "test" )
302+ ( "a" , "c" ) , pred_labels , provenance = SequenceProvenance (0 , "test" )
249303 )
250304 with pytest .raises (ValueError ):
251305 compute_scores ([[pred_sequence ]], [[ref_sequence ]])
0 commit comments