diff --git a/conlleval.py b/conlleval.py
new file mode 100644
index 0000000..b05c527
--- /dev/null
+++ b/conlleval.py
@@ -0,0 +1,235 @@
+"""
+This script applies to the IOB2 or IOBES tagging scheme.
+If you are using a different scheme, please convert to IOB2 or IOBES.
+
+IOB2:
+- B = begin,
+- I = inside but not the first,
+- O = outside
+
+e.g.
+John lives in New York City .
+B-PER O O B-LOC I-LOC I-LOC O
+
+IOBES:
+- B = begin,
+- E = end,
+- S = singleton,
+- I = inside but not the first or the last,
+- O = outside
+
+e.g.
+John lives in New York City .
+S-PER O O B-LOC I-LOC E-LOC O
+
+prefix: IOBES
+chunk_type: PER, LOC, etc.
+"""
+from __future__ import division, print_function, unicode_literals
+
+import sys
+from collections import defaultdict
+
+def split_tag(chunk_tag):
+    """
+    split chunk tag into IOBES prefix and chunk_type
+    e.g.
+    B-PER -> (B, PER)
+    O -> (O, None)
+    """
+    if chunk_tag == 'O':
+        return ('O', None)
+    return chunk_tag.split('-', maxsplit=1)
+
+def is_chunk_end(prev_tag, tag):
+    """
+    check if the previous chunk ended between the previous and current word
+    e.g.
+    (B-PER, I-PER) -> False
+    (B-LOC, O) -> True
+
+    Note: in case of contradicting tags, e.g. (B-PER, I-LOC),
+    this is considered as (B-PER, B-LOC)
+    """
+    prefix1, chunk_type1 = split_tag(prev_tag)
+    prefix2, chunk_type2 = split_tag(tag)
+
+    if prefix1 == 'O':
+        return False
+    if prefix2 == 'O':
+        return prefix1 != 'O'
+
+    if chunk_type1 != chunk_type2:
+        return True
+
+    return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']
+
+def is_chunk_start(prev_tag, tag):
+    """
+    check if a new chunk started between the previous and current word
+    """
+    prefix1, chunk_type1 = split_tag(prev_tag)
+    prefix2, chunk_type2 = split_tag(tag)
+
+    if prefix2 == 'O':
+        return False
+    if prefix1 == 'O':
+        return prefix2 != 'O'
+
+    if chunk_type1 != chunk_type2:
+        return True
+
+    return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S']
+
+
+def calc_metrics(tp, p, t, percent=True):
+    """
+    compute overall precision, recall and FB1 (default values are 0.0)
+    if percent is True, return 100 * original decimal value
+    """
+    precision = tp / p if p else 0
+    recall = tp / t if t else 0
+    fb1 = 2 * precision * recall / (precision + recall) if precision + recall else 0
+    if percent:
+        return 100 * precision, 100 * recall, 100 * fb1
+    else:
+        return precision, recall, fb1
+
+
+def count_chunks(true_seqs, pred_seqs):
+    """
+    true_seqs: a list of true tags
+    pred_seqs: a list of predicted tags
+
+    return:
+    correct_chunks: a dict (counter),
+                    key = chunk types,
+                    value = number of correctly identified chunks per type
+    true_chunks:    a dict, number of true chunks per type
+    pred_chunks:    a dict, number of identified chunks per type
+
+    correct_counts, true_counts, pred_counts: similar to above, but for tags
+    """
+    correct_chunks = defaultdict(int)
+    true_chunks = defaultdict(int)
+    pred_chunks = defaultdict(int)
+
+    correct_counts = defaultdict(int)
+    true_counts = defaultdict(int)
+    pred_counts = defaultdict(int)
+
+    prev_true_tag, prev_pred_tag = 'O', 'O'
+    correct_chunk = None
+
+    for true_tag, pred_tag in zip(true_seqs, pred_seqs):
+        if true_tag == pred_tag:
+            correct_counts[true_tag] += 1
+        true_counts[true_tag] += 1
+        pred_counts[pred_tag] += 1
+
+        _, true_type = split_tag(true_tag)
+        _, pred_type = split_tag(pred_tag)
+
+        if correct_chunk is not None:
+            true_end = is_chunk_end(prev_true_tag, true_tag)
+            pred_end = is_chunk_end(prev_pred_tag, pred_tag)
+
+            if pred_end and true_end:
+                correct_chunks[correct_chunk] += 1
+                correct_chunk = None
+            elif pred_end != true_end or true_type != pred_type:
+                correct_chunk = None
+
+        true_start = is_chunk_start(prev_true_tag, true_tag)
+        pred_start = is_chunk_start(prev_pred_tag, pred_tag)
+
+        if true_start and pred_start and true_type == pred_type:
+            correct_chunk = true_type
+        if true_start:
+            true_chunks[true_type] += 1
+        if pred_start:
+            pred_chunks[pred_type] += 1
+
+        prev_true_tag, prev_pred_tag = true_tag, pred_tag
+    if correct_chunk is not None:
+        correct_chunks[correct_chunk] += 1
+
+    return (correct_chunks, true_chunks, pred_chunks,
+        correct_counts, true_counts, pred_counts)
+
+def get_result(correct_chunks, true_chunks, pred_chunks,
+    correct_counts, true_counts, pred_counts, verbose=True):
+    """
+    if verbose, print overall performance, as well as performance per chunk type;
+    otherwise, simply return overall prec, rec, f1 scores
+    """
+    # sum counts
+    sum_correct_chunks = sum(correct_chunks.values())
+    sum_true_chunks = sum(true_chunks.values())
+    sum_pred_chunks = sum(pred_chunks.values())
+
+    sum_correct_counts = sum(correct_counts.values())
+    sum_true_counts = sum(true_counts.values())
+
+    nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O')
+    nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O')
+
+    chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks))))
+
+    # compute overall precision, recall and FB1 (default values are 0.0)
+    prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks)
+    res = (prec, rec, f1)
+    if not verbose:
+        return res
+
+    # print overall performance, and performance per chunk type
+
+    print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='')
+    print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='')
+
+    print("accuracy: %6.2f%%; (non-O)" % (100*nonO_correct_counts/nonO_true_counts))
+    print("accuracy: %6.2f%%; " % (100*sum_correct_counts/sum_true_counts), end='')
+    print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1))
+
+    # for each chunk type, compute precision, recall and FB1 (default values are 0.0)
+    for t in chunk_types:
+        prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t])
+        print("%17s: " % t, end='')
+        print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" %
+                    (prec, rec, f1), end='')
+        print("  %d" % pred_chunks[t])
+
+    return res
+    # you can generate LaTeX output for tables like in
+    # http://cnts.uia.ac.be/conll2003/ner/example.tex
+    # but I'm not implementing this
+
+def evaluate(true_seqs, pred_seqs, verbose=True):
+    (correct_chunks, true_chunks, pred_chunks,
+        correct_counts, true_counts, pred_counts) = count_chunks(true_seqs, pred_seqs)
+    result = get_result(correct_chunks, true_chunks, pred_chunks,
+        correct_counts, true_counts, pred_counts, verbose=verbose)
+    return result
+
+def evaluate_conll_file(fileIterator):
+    true_seqs, pred_seqs = [], []
+
+    for line in fileIterator:
+        cols = line.strip().split()
+        # each non-empty line must contain >= 3 columns
+        if not cols:
+            true_seqs.append('O')
+            pred_seqs.append('O')
+        elif len(cols) < 3:
+            raise IOError("conlleval: too few columns in line %s\n" % line)
+        else:
+            # extract tags from last 2 columns
+            true_seqs.append(cols[-2])
+            pred_seqs.append(cols[-1])
+    return evaluate(true_seqs, pred_seqs)
+
+if __name__ == '__main__':
+    """
+    usage: conlleval < file
+    """
+    evaluate_conll_file(sys.stdin)
diff --git a/train.py b/train.py
index 76f2a8f..f6cf9b5 100644
--- a/train.py
+++ b/train.py
@@ -7,6 +7,7 @@
 import os
 import numpy as np
 import argparse
+from conlleval import evaluate_conll_file
 
 def train(model, iterator, optimizer, criterion):
     model.train()
@@ -65,6 +66,10 @@ def eval(model, iterator, f):
                 fout.write(f"{w} {t} {p}\n")
             fout.write("\n")
 
+    # report chunk-level precision, recall and F1 on the predictions written to "temp"
+    with open("temp") as fin:
+        evaluate_conll_file(fin)
+
     ## calc metric
     y_true = np.array([tag2idx[line.split()[1]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])
    y_pred = np.array([tag2idx[line.split()[2]] for line in open("temp", 'r').read().splitlines() if len(line) > 0])
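For reference, the new module can also be exercised on its own, either from Python or from the shell. A minimal sketch follows; the tag sequences are made up for illustration, and it assumes conlleval.py is importable from the working directory (e.g. the repository root):

    from conlleval import evaluate, evaluate_conll_file

    # chunk-level scores straight from two tag lists (IOB2 tags here)
    true_seqs = ['B-PER', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'O']
    pred_seqs = ['B-PER', 'O', 'O', 'B-LOC', 'I-LOC', 'O', 'O']
    prec, rec, f1 = evaluate(true_seqs, pred_seqs, verbose=False)

    # or score a whole prediction file ("token gold-tag pred-tag" per line,
    # blank line between sentences), exactly as train.py now does with "temp"
    with open("temp") as fin:
        evaluate_conll_file(fin)

The same report is available from the command line with: python conlleval.py < temp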