diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index 7aefced..2b3e1ab 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -21,8 +21,7 @@ jobs: - name: Install run: | python -m pip install --upgrade pip - pip install . - pip install -r docs/requirements.txt + pip install .[docs] - name: Build documentation run: | make --directory=docs html diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e124550..18e22b6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: [3.6, 3.7] + python-version: [3.7, 3.8, 3.9, "3.10"] steps: - uses: actions/checkout@v1 @@ -28,7 +28,7 @@ jobs: - name: Install from source run: | python -m pip install --upgrade pip - pip install . + pip install .[tests] - name: Lint with flake8 run: | pip install flake8 @@ -38,5 +38,4 @@ jobs: flake8 ./pyannote --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - pip install pytest pytest diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 7ac2292..0000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -Sphinx==2.2.2 -ipython==7.10.1 -sphinx_rtd_theme==0.4.3 diff --git a/pyannote/metrics/__init__.py b/pyannote/metrics/__init__.py index 5a7a952..cdeaf18 100644 --- a/pyannote/metrics/__init__.py +++ b/pyannote/metrics/__init__.py @@ -26,9 +26,8 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr -from .base import f_measure - from ._version import get_versions +from .base import f_measure __version__ = get_versions()["version"] del get_versions diff --git a/pyannote/metrics/base.py b/pyannote/metrics/base.py index b15000b..c85568e 100755 --- a/pyannote/metrics/base.py +++ b/pyannote/metrics/base.py @@ -25,14 +25,17 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr +from typing import List, Union, Optional, Set, Tuple - -import scipy.stats -import pandas as pd import numpy as np +import pandas as pd +import scipy.stats +from pyannote.core import Annotation, Timeline +from pyannote.metrics.types import Details, MetricComponents -class BaseMetric(object): + +class BaseMetric: """ :class:`BaseMetric` is the base class for most pyannote evaluation metrics. @@ -43,23 +46,23 @@ class BaseMetric(object): """ @classmethod - def metric_name(cls): + def metric_name(cls) -> str: raise NotImplementedError( cls.__name__ + " is missing a 'metric_name' class method. " - "It should return the name of the metric as string." + "It should return the name of the metric as string." ) @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: raise NotImplementedError( cls.__name__ + " is missing a 'metric_components' class method. " - "It should return the list of names of metric components." + "It should return the list of names of metric components." 
) def __init__(self, **kwargs): super(BaseMetric, self).__init__() self.metric_name_ = self.__class__.metric_name() - self.components_ = set(self.__class__.metric_components()) + self.components_: Set[str] = set(self.__class__.metric_components()) self.reset() def init_components(self): @@ -67,20 +70,22 @@ def init_components(self): def reset(self): """Reset accumulated components and metric values""" - self.accumulated_ = dict() - self.results_ = list() + self.accumulated_: Details = dict() + self.results_: List = list() for value in self.components_: self.accumulated_[value] = 0.0 - def __get_name(self): - return self.__class__.metric_name() - - name = property(fget=__get_name, doc="Metric name.") + @property + def name(self): + """Metric name.""" + return self.metric_name() # TODO: use joblib/locky to allow parallel processing? # TODO: signature could be something like __call__(self, reference_iterator, hypothesis_iterator, ...) - def __call__(self, reference, hypothesis, detailed=False, uri=None, **kwargs): + def __call__(self, reference: Union[Timeline, Annotation], + hypothesis: Union[Timeline, Annotation], + detailed: bool = False, uri: Optional[str] = None, **kwargs): """Compute metric value and accumulate components Parameters @@ -123,7 +128,7 @@ def __call__(self, reference, hypothesis, detailed=False, uri=None, **kwargs): return components[self.metric_name_] - def report(self, display=False): + def report(self, display: bool = False) -> pd.DataFrame: """Evaluation report Parameters @@ -217,7 +222,7 @@ def __abs__(self): """Compute metric value from accumulated components""" return self.compute_metric(self.accumulated_) - def __getitem__(self, component): + def __getitem__(self, component: str) -> Union[float, Details]: """Get value of accumulated `component`. Parameters @@ -241,7 +246,10 @@ def __iter__(self): for uri, component in self.results_: yield uri, component - def compute_components(self, reference, hypothesis, **kwargs): + def compute_components(self, + reference: Union[Timeline, Annotation], + hypothesis: Union[Timeline, Annotation], + **kwargs) -> Details: """Compute metric components Parameters @@ -260,11 +268,11 @@ def compute_components(self, reference, hypothesis, **kwargs): """ raise NotImplementedError( self.__class__.__name__ + " is missing a 'compute_components' method." - "It should return a dictionary where keys are component names " - "and values are component values." + "It should return a dictionary where keys are component names " + "and values are component values." ) - def compute_metric(self, components): + def compute_metric(self, components: Details): """Compute metric value from computed `components` Parameters @@ -280,11 +288,12 @@ def compute_metric(self, components): """ raise NotImplementedError( self.__class__.__name__ + " is missing a 'compute_metric' method. " - "It should return the actual value of the metric based " - "on the precomputed component dictionary given as input." + "It should return the actual value of the metric based " + "on the precomputed component dictionary given as input." 
) - def confidence_interval(self, alpha=0.9): + def confidence_interval(self, alpha: float = 0.9) \ + -> Tuple[float, Tuple[float, float]]: """Compute confidence interval on accumulated metric values Parameters @@ -333,10 +342,10 @@ def metric_name(cls): return PRECISION_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED] - def compute_metric(self, components): + def compute_metric(self, components: Details) -> float: """Compute precision from `components`""" numerator = components[PRECISION_RELEVANT_RETRIEVED] denominator = components[PRECISION_RETRIEVED] @@ -371,10 +380,10 @@ def metric_name(cls): return RECALL_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED] - def compute_metric(self, components): + def compute_metric(self, components: Details) -> float: """Compute recall from `components`""" numerator = components[RECALL_RELEVANT_RETRIEVED] denominator = components[RECALL_RELEVANT] @@ -387,7 +396,7 @@ def compute_metric(self, components): return numerator / denominator -def f_measure(precision, recall, beta=1.0): +def f_measure(precision: float, recall: float, beta=1.0) -> float: """Compute f-measure f-measure is defined as follows: @@ -398,4 +407,3 @@ def f_measure(precision, recall, beta=1.0): if precision + recall == 0.0: return 0 return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall) - diff --git a/pyannote/metrics/binary_classification.py b/pyannote/metrics/binary_classification.py index aa7c5b1..1c80738 100644 --- a/pyannote/metrics/binary_classification.py +++ b/pyannote/metrics/binary_classification.py @@ -26,15 +26,21 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr -import numpy as np from collections import Counter +from typing import Tuple + +import numpy as np import sklearn.metrics +from numpy.typing import ArrayLike from sklearn.base import BaseEstimator from sklearn.calibration import CalibratedClassifierCV from sklearn.model_selection._split import _CVIterableWrapper +from .types import CalibrationMethod -def det_curve(y_true, scores, distances=False): + +def det_curve(y_true: ArrayLike, scores: ArrayLike, distances: bool = False) \ + -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]: """DET curve Parameters @@ -71,13 +77,16 @@ def det_curve(y_true, scores, distances=False): # estimate equal error rate eer_index = np.where(fpr > fnr)[0][0] - eer = .25 * (fpr[eer_index-1] + fpr[eer_index] + - fnr[eer_index-1] + fnr[eer_index]) + eer = .25 * (fpr[eer_index - 1] + fpr[eer_index] + + fnr[eer_index - 1] + fnr[eer_index]) return fpr, fnr, thresholds, eer -def precision_recall_curve(y_true, scores, distances=False): +def precision_recall_curve(y_true: ArrayLike, + scores: ArrayLike, + distances: bool = False) \ + -> Tuple[np.ndarray, np.ndarray, np.ndarray, float]: """Precision-recall curve Parameters @@ -120,18 +129,18 @@ class _Passthrough(BaseEstimator): """Dummy binary classifier used by score Calibration class""" def __init__(self): - super(_Passthrough, self).__init__() + super().__init__() self.classes_ = np.array([False, True], dtype=np.bool) def fit(self, scores, y_true): return self - def decision_function(self, scores): + def decision_function(self, scores: ArrayLike): """Returns the input scores unchanged""" return scores -class Calibration(object): +class Calibration: """Probability calibration for binary classification tasks 
Parameters @@ -154,12 +163,12 @@ class Calibration(object): """ - def __init__(self, equal_priors=False, method='isotonic'): - super(Calibration, self).__init__() + def __init__(self, equal_priors: bool = False, + method: CalibrationMethod = 'isotonic'): self.method = method self.equal_priors = equal_priors - def fit(self, scores, y_true): + def fit(self, scores: ArrayLike, y_true: ArrayLike): """Train calibration Parameters @@ -209,7 +218,7 @@ def fit(self, scores, y_true): return self - def transform(self, scores): + def transform(self, scores: ArrayLike): """Calibrate scores into probabilities Parameters diff --git a/pyannote/metrics/cli.py b/pyannote/metrics/cli.py index f0089f3..e04bbde 100644 --- a/pyannote/metrics/cli.py +++ b/pyannote/metrics/cli.py @@ -90,52 +90,44 @@ """ -# command line parsing -from docopt import docopt - -import sys +import functools import json +import sys import warnings -import functools + import numpy as np import pandas as pd -from tabulate import tabulate - -from pyannote.core import Timeline +# command line parsing +from docopt import docopt from pyannote.core import Annotation -from pyannote.database.util import load_rttm - +from pyannote.core import Timeline # evaluation protocols from pyannote.database import get_protocol from pyannote.database.util import get_annotated +from pyannote.database.util import load_rttm +from tabulate import tabulate -from pyannote.metrics.detection import DetectionErrorRate from pyannote.metrics.detection import DetectionAccuracy -from pyannote.metrics.detection import DetectionRecall +from pyannote.metrics.detection import DetectionErrorRate from pyannote.metrics.detection import DetectionPrecision - -from pyannote.metrics.segmentation import SegmentationPurity -from pyannote.metrics.segmentation import SegmentationCoverage -from pyannote.metrics.segmentation import SegmentationPrecision -from pyannote.metrics.segmentation import SegmentationRecall - -from pyannote.metrics.diarization import GreedyDiarizationErrorRate +from pyannote.metrics.detection import DetectionRecall +from pyannote.metrics.diarization import DiarizationCoverage from pyannote.metrics.diarization import DiarizationErrorRate from pyannote.metrics.diarization import DiarizationPurity -from pyannote.metrics.diarization import DiarizationCoverage - +from pyannote.metrics.diarization import GreedyDiarizationErrorRate from pyannote.metrics.identification import IdentificationErrorRate from pyannote.metrics.identification import IdentificationPrecision from pyannote.metrics.identification import IdentificationRecall - +from pyannote.metrics.segmentation import SegmentationCoverage +from pyannote.metrics.segmentation import SegmentationPrecision +from pyannote.metrics.segmentation import SegmentationPurity +from pyannote.metrics.segmentation import SegmentationRecall from pyannote.metrics.spotting import LowLatencySpeakerSpotting showwarning_orig = warnings.showwarning def showwarning(message, category, *args, **kwargs): - import sys - print(category.__name__ + ":", str(message)) diff --git a/pyannote/metrics/detection.py b/pyannote/metrics/detection.py index 4211823..3b1929b 100755 --- a/pyannote/metrics/detection.py +++ b/pyannote/metrics/detection.py @@ -26,8 +26,12 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr # Marvin LAVECHIN +from typing import Optional, Tuple + +from pyannote.core import Annotation, Timeline from .base import BaseMetric, f_measure +from .types import Details, MetricComponents from .utils import UEMSupportMixin DER_NAME = 
'detection error rate' @@ -61,19 +65,23 @@ class DetectionErrorRate(UEMSupportMixin, BaseMetric): """ @classmethod - def metric_name(cls): + def metric_name(cls) -> str: return DER_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [DER_TOTAL, DER_FALSE_ALARM, DER_MISS] - def __init__(self, collar=0.0, skip_overlap=False, **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, **kwargs): super(DetectionErrorRate, self).__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: reference, hypothesis, uem = self.uemify( reference, hypothesis, uem=uem, @@ -101,7 +109,7 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: error = 1. * (detail[DER_FALSE_ALARM] + detail[DER_MISS]) total = 1. * detail[DER_TOTAL] if total == 0.: @@ -149,11 +157,15 @@ def metric_name(cls): return ACCURACY_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [ACCURACY_TRUE_POSITIVE, ACCURACY_TRUE_NEGATIVE, ACCURACY_FALSE_POSITIVE, ACCURACY_FALSE_NEGATIVE] - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: reference, hypothesis, uem = self.uemify( reference, hypothesis, uem=uem, @@ -190,7 +202,7 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: numerator = 1. * (detail[ACCURACY_TRUE_NEGATIVE] + detail[ACCURACY_TRUE_POSITIVE]) denominator = 1. * (detail[ACCURACY_TRUE_NEGATIVE] + @@ -237,10 +249,14 @@ def metric_name(cls): return PRECISION_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED] - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: reference, hypothesis, uem = self.uemify( reference, hypothesis, uem=uem, @@ -266,7 +282,7 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: relevant_retrieved = 1. * detail[PRECISION_RELEVANT_RETRIEVED] retrieved = 1. 
* detail[PRECISION_RETRIEVED] if retrieved == 0.: @@ -308,10 +324,14 @@ def metric_name(cls): return RECALL_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED] - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: reference, hypothesis, uem = self.uemify( reference, hypothesis, uem=uem, @@ -337,7 +357,7 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: relevant_retrieved = 1. * detail[RECALL_RELEVANT_RETRIEVED] relevant = 1. * detail[RECALL_RELEVANT] if relevant == 0.: @@ -387,14 +407,18 @@ def metric_name(cls): def metric_components(cls): return [DFS_PRECISION_RETRIEVED, DFS_RECALL_RELEVANT, DFS_RELEVANT_RETRIEVED] - def __init__(self, collar=0.0, skip_overlap=False, - beta=1., **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, + beta: float = 1., **kwargs): super(DetectionPrecisionRecallFMeasure, self).__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.beta = beta - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: reference, hypothesis, uem = self.uemify( reference, hypothesis, uem=uem, @@ -428,11 +452,12 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: _, _, value = self.compute_metrics(detail=detail) return value - def compute_metrics(self, detail=None): + def compute_metrics(self, detail: Optional[Details] = None) \ + -> Tuple[float, float, float]: detail = self.accumulated_ if detail is None else detail precision_retrieved = detail[DFS_PRECISION_RETRIEVED] @@ -458,10 +483,11 @@ def compute_metrics(self, detail=None): DCF_NAME = 'detection cost function' -DCF_POS_TOTAL = 'positive class total' # Total duration of positive class. -DCF_NEG_TOTAL = 'negative class total' # Total duration of negative class. -DCF_FALSE_ALARM = 'false alarm' # Total duration of false alarms. -DCF_MISS = 'miss' # Total duration of misses. +DCF_POS_TOTAL = 'positive class total' # Total duration of positive class. +DCF_NEG_TOTAL = 'negative class total' # Total duration of negative class. +DCF_FALSE_ALARM = 'false alarm' # Total duration of false alarms. +DCF_MISS = 'miss' # Total duration of misses. + class DetectionCostFunction(UEMSupportMixin, BaseMetric): """Detection cost function. @@ -503,6 +529,7 @@ class DetectionCostFunction(UEMSupportMixin, BaseMetric): ---------- "OpenSAT19 Evaluation Plan v2." 
https://www.nist.gov/system/files/documents/2018/11/05/opensat19_evaluation_plan_v2_11-5-18.pdf """ + def __init__(self, collar=0.0, skip_overlap=False, fa_weight=0.25, miss_weight=0.75, **kwargs): super(DetectionCostFunction, self).__init__(**kwargs) @@ -516,10 +543,15 @@ def metric_name(cls): return DCF_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [DCF_POS_TOTAL, DCF_NEG_TOTAL, DCF_FALSE_ALARM, DCF_MISS] - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: + reference, hypothesis, uem = self.uemify( reference, hypothesis, uem=uem, collar=self.collar, skip_overlap=self.skip_overlap, @@ -548,20 +580,20 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): fa_dur += (r_ & h).duration components = { - DCF_POS_TOTAL : pos_dur, - DCF_NEG_TOTAL : neg_dur, - DCF_MISS : miss_dur, - DCF_FALSE_ALARM : fa_dur} + DCF_POS_TOTAL: pos_dur, + DCF_NEG_TOTAL: neg_dur, + DCF_MISS: miss_dur, + DCF_FALSE_ALARM: fa_dur} return components - def compute_metric(self, components): + def compute_metric(self, components: Details) -> float: def _compute_rate(num, denom): if denom == 0.0: if num == 0.0: return 0.0 return 1.0 - return num/denom + return num / denom # Compute false alarm rate. neg_dur = components[DCF_NEG_TOTAL] @@ -573,4 +605,4 @@ def _compute_rate(num, denom): miss_dur = components[DCF_MISS] miss_rate = _compute_rate(miss_dur, pos_dur) - return self.fa_weight*fa_rate + self.miss_weight*miss_rate + return self.fa_weight * fa_rate + self.miss_weight * miss_rate diff --git a/pyannote/metrics/diarization.py b/pyannote/metrics/diarization.py index 512370c..8bc583d 100755 --- a/pyannote/metrics/diarization.py +++ b/pyannote/metrics/diarization.py @@ -27,16 +27,23 @@ # Hervé BREDIN - http://herve.niderb.fr """Metrics for diarization""" +from typing import Optional, Dict, TYPE_CHECKING import numpy as np - -from .matcher import HungarianMapper -from .matcher import GreedyMapper +from pyannote.core import Annotation, Timeline +from pyannote.core.utils.types import Label from .base import BaseMetric, f_measure -from .utils import UEMSupportMixin from .identification import IdentificationErrorRate +from .matcher import GreedyMapper +from .matcher import HungarianMapper +from .types import Details, MetricComponents +from .utils import UEMSupportMixin + +if TYPE_CHECKING: + pass +# TODO: can't we put these as class attributes? 
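The hunks below annotate the diarization metrics. As a quick reference for reviewers, here is a usage sketch of the `DiarizationErrorRate` API being typed in this file; it is illustrative only (not part of the changeset), and the file URI, segments and speaker labels are invented.

```python
# Illustrative sketch, not part of the diff: exercising the DiarizationErrorRate
# API whose signatures are annotated in this file. URI, segments and labels are made up.
from pyannote.core import Annotation, Segment
from pyannote.metrics.diarization import DiarizationErrorRate

reference = Annotation(uri="file1")
reference[Segment(0.0, 10.0)] = "alice"
reference[Segment(12.0, 20.0)] = "bob"

hypothesis = Annotation(uri="file1")
hypothesis[Segment(0.0, 11.0)] = "spk1"
hypothesis[Segment(12.0, 19.0)] = "spk2"

metric = DiarizationErrorRate(collar=0.25, skip_overlap=False)
rate = metric(reference, hypothesis)                     # float, also accumulates components
mapping = metric.optimal_mapping(reference, hypothesis)  # Dict[Label, Label], e.g. {'spk1': 'alice', 'spk2': 'bob'}
print(rate, mapping)
```

Calling the metric accumulates its components, so `abs(metric)` and `metric.report()` aggregate over every file evaluated so far.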
DER_NAME = 'diarization error rate' @@ -92,15 +99,18 @@ class DiarizationErrorRate(IdentificationErrorRate): """ @classmethod - def metric_name(cls): + def metric_name(cls) -> str: return DER_NAME - def __init__(self, collar=0.0, skip_overlap=False, **kwargs): - super(DiarizationErrorRate, self).__init__( - collar=collar, skip_overlap=skip_overlap, **kwargs) + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, + **kwargs): + super().__init__(collar=collar, skip_overlap=skip_overlap, **kwargs) self.mapper_ = HungarianMapper() - def optimal_mapping(self, reference, hypothesis, uem=None): + def optimal_mapping(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None) -> Dict[Label, Label]: """Optimal label mapping Parameters @@ -126,8 +136,11 @@ def optimal_mapping(self, reference, hypothesis, uem=None): # call hungarian mapper return self.mapper_(hypothesis, reference) - def compute_components(self, reference, hypothesis, uem=None, **kwargs): - + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: # crop reference and hypothesis to evaluated regions (uem) # remove collars around reference segment boundaries # remove overlap regions (if requested) @@ -151,7 +164,7 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): # NOTE that collar is set to 0.0 because 'uemify' has already # been applied (same reason for setting skip_overlap to False) mapped = hypothesis.rename_labels(mapping=mapping) - return super(DiarizationErrorRate, self)\ + return super(DiarizationErrorRate, self) \ .compute_components(reference, mapped, uem=uem, collar=0.0, skip_overlap=False, **kwargs) @@ -211,12 +224,15 @@ class GreedyDiarizationErrorRate(IdentificationErrorRate): def metric_name(cls): return DER_NAME - def __init__(self, collar=0.0, skip_overlap=False, **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, **kwargs): super(GreedyDiarizationErrorRate, self).__init__( collar=collar, skip_overlap=skip_overlap, **kwargs) self.mapper_ = GreedyMapper() - def greedy_mapping(self, reference, hypothesis, uem=None): + def greedy_mapping(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None) -> Dict[Label, Label]: """Greedy label mapping Parameters @@ -236,8 +252,11 @@ def greedy_mapping(self, reference, hypothesis, uem=None): reference, hypothesis = self.uemify(reference, hypothesis, uem=uem) return self.mapper_(hypothesis, reference) - def compute_components(self, reference, hypothesis, uem=None, **kwargs): - + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: # crop reference and hypothesis to evaluated regions (uem) # remove collars around reference segment boundaries # remove overlap regions (if requested) @@ -261,7 +280,7 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): # NOTE that collar is set to 0.0 because 'uemify' has already # been applied (same reason for setting skip_overlap to False) mapped = hypothesis.rename_labels(mapping=mapping) - return super(GreedyDiarizationErrorRate, self)\ + return super(GreedyDiarizationErrorRate, self) \ .compute_components(reference, mapped, uem=uem, collar=0.0, skip_overlap=False, **kwargs) @@ -339,7 +358,7 @@ def metric_name(cls): return JER_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [ 
JER_SPEAKER_COUNT, JER_SPEAKER_ERROR, @@ -350,7 +369,11 @@ def __init__(self, collar=0.0, skip_overlap=False, **kwargs): collar=collar, skip_overlap=skip_overlap, **kwargs) self.mapper_ = HungarianMapper() - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: # crop reference and hypothesis to evaluated regions (uem) # remove collars around reference segment boundaries @@ -412,7 +435,7 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: return detail[JER_SPEAKER_ERROR] / detail[JER_SPEAKER_COUNT] @@ -447,14 +470,18 @@ def metric_name(cls): def metric_components(cls): return [PURITY_TOTAL, PURITY_CORRECT] - def __init__(self, collar=0.0, skip_overlap=False, - weighted=True, **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, + weighted: bool = True, **kwargs): super(DiarizationPurity, self).__init__(**kwargs) self.weighted = weighted self.collar = collar self.skip_overlap = skip_overlap - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: detail = self.init_components() @@ -485,7 +512,7 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: if detail[PURITY_TOTAL] > 0.: return detail[PURITY_CORRECT] / detail[PURITY_TOTAL] return 1. @@ -516,14 +543,18 @@ class DiarizationCoverage(DiarizationPurity): def metric_name(cls): return COVERAGE_NAME - def __init__(self, collar=0.0, skip_overlap=False, - weighted=True, **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, + weighted: bool = True, **kwargs): super(DiarizationCoverage, self).__init__( collar=collar, skip_overlap=skip_overlap, weighted=weighted, **kwargs) - def compute_components(self, reference, hypothesis, uem=None, **kwargs): - return super(DiarizationCoverage, self)\ + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: + return super(DiarizationCoverage, self) \ .compute_components(hypothesis, reference, uem=uem, **kwargs) @@ -566,21 +597,25 @@ def metric_name(cls): return PURITY_COVERAGE_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [PURITY_COVERAGE_LARGEST_CLASS, PURITY_COVERAGE_TOTAL_CLUSTER, PURITY_COVERAGE_LARGEST_CLUSTER, PURITY_COVERAGE_TOTAL_CLASS] - def __init__(self, collar=0.0, skip_overlap=False, - weighted=True, beta=1., **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, + weighted: bool = True, beta: float = 1., **kwargs): super(DiarizationPurityCoverageFMeasure, self).__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.weighted = weighted self.beta = beta - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: detail = self.init_components() @@ -625,11 +660,11 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): # compute 
purity detail[PURITY_NAME] = \ 1. if detail[PURITY_COVERAGE_TOTAL_CLUSTER] == 0. \ - else detail[PURITY_COVERAGE_LARGEST_CLASS] / detail[PURITY_COVERAGE_TOTAL_CLUSTER] + else detail[PURITY_COVERAGE_LARGEST_CLASS] / detail[PURITY_COVERAGE_TOTAL_CLUSTER] # compute coverage detail[COVERAGE_NAME] = \ 1. if detail[PURITY_COVERAGE_TOTAL_CLASS] == 0. \ - else detail[PURITY_COVERAGE_LARGEST_CLUSTER] / detail[PURITY_COVERAGE_TOTAL_CLASS] + else detail[PURITY_COVERAGE_LARGEST_CLUSTER] / detail[PURITY_COVERAGE_TOTAL_CLASS] return detail @@ -643,11 +678,11 @@ def compute_metrics(self, detail=None): purity = \ 1. if detail[PURITY_COVERAGE_TOTAL_CLUSTER] == 0. \ - else detail[PURITY_COVERAGE_LARGEST_CLASS] / detail[PURITY_COVERAGE_TOTAL_CLUSTER] + else detail[PURITY_COVERAGE_LARGEST_CLASS] / detail[PURITY_COVERAGE_TOTAL_CLUSTER] coverage = \ 1. if detail[PURITY_COVERAGE_TOTAL_CLASS] == 0. \ - else detail[PURITY_COVERAGE_LARGEST_CLUSTER] / detail[PURITY_COVERAGE_TOTAL_CLASS] + else detail[PURITY_COVERAGE_LARGEST_CLUSTER] / detail[PURITY_COVERAGE_TOTAL_CLASS] return purity, coverage, f_measure(purity, coverage, beta=self.beta) @@ -679,12 +714,17 @@ def metric_name(cls): def metric_components(cls): return [HOMOGENEITY_ENTROPY, HOMOGENEITY_CROSS_ENTROPY] - def __init__(self, collar=0.0, skip_overlap=False, **kwargs): + def __init__(self, collar: float = 0.0, skip_overlap: bool = False, + **kwargs): super(DiarizationHomogeneity, self).__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap - def compute_components(self, reference, hypothesis, uem=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: detail = self.init_components() @@ -745,6 +785,10 @@ class DiarizationCompleteness(DiarizationHomogeneity): def metric_name(cls): return COMPLETENESS_NAME - def compute_components(self, reference, hypothesis, uem=None, **kwargs): - return super(DiarizationCompleteness, self)\ + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: + return super(DiarizationCompleteness, self) \ .compute_components(hypothesis, reference, uem=uem, **kwargs) diff --git a/pyannote/metrics/errors/identification.py b/pyannote/metrics/errors/identification.py index 254c2af..5370b02 100755 --- a/pyannote/metrics/errors/identification.py +++ b/pyannote/metrics/errors/identification.py @@ -26,17 +26,19 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr # Benjamin MAURICE - maurice@limsi.fr +from typing import Optional, TYPE_CHECKING import numpy as np +from pyannote.core import Annotation, Timeline from scipy.optimize import linear_sum_assignment +from ..identification import UEMSupportMixin from ..matcher import LabelMatcher -from pyannote.core import Annotation - from ..matcher import MATCH_CORRECT, MATCH_CONFUSION, \ MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM -from ..identification import UEMSupportMixin +if TYPE_CHECKING: + from xarray import DataArray REFERENCE_TOTAL = 'reference' HYPOTHESIS_TOTAL = 'hypothesis' @@ -47,7 +49,7 @@ BOTH_INCORRECT = 'both_incorrect' -class IdentificationErrorAnalysis(UEMSupportMixin, object): +class IdentificationErrorAnalysis(UEMSupportMixin): """ Parameters @@ -60,14 +62,18 @@ class IdentificationErrorAnalysis(UEMSupportMixin, object): Defaults to False (i.e. keep overlap regions). 
""" - def __init__(self, collar=0., skip_overlap=False): + def __init__(self, collar: float = 0., skip_overlap: bool = False): - super(IdentificationErrorAnalysis, self).__init__() + super().__init__() self.matcher = LabelMatcher() self.collar = collar self.skip_overlap = skip_overlap - def difference(self, reference, hypothesis, uem=None, uemified=False): + def difference(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + uemified: bool = False): """Get error analysis as `Annotation` Labels are (status, reference_label, hypothesis_label) tuples. @@ -133,7 +139,13 @@ def _match_errors(self, before, after): a_type, a_ref, a_hyp = after return (b_ref == a_ref) * (1 + (b_type == a_type) + (b_hyp == a_hyp)) - def regression(self, reference, before, after, uem=None, uemified=False): + # TODO : return type + def regression(self, + reference: Annotation, + before: Annotation, + after: Annotation, + uem: Optional[Timeline] = None, + uemified: bool = False): _, before, errors_before = self.difference( reference, before, uem=uem, uemified=True) @@ -223,7 +235,10 @@ def regression(self, reference, before, after, uem=None, uemified=False): else: return behaviors - def matrix(self, reference, hypothesis, uem=None): + def matrix(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None) -> 'DataArray': reference, hypothesis, errors = self.difference( reference, hypothesis, uem=uem, uemified=True) @@ -252,10 +267,10 @@ def matrix(self, reference, hypothesis, uem=None): # prepend duration columns before the detailed confusion matrix hLabels = [ - REFERENCE_TOTAL, HYPOTHESIS_TOTAL, - MATCH_CORRECT, MATCH_CONFUSION, - MATCH_FALSE_ALARM, MATCH_MISSED_DETECTION - ] + hLabels + REFERENCE_TOTAL, HYPOTHESIS_TOTAL, + MATCH_CORRECT, MATCH_CONFUSION, + MATCH_FALSE_ALARM, MATCH_MISSED_DETECTION + ] + hLabels # initialize empty matrix diff --git a/pyannote/metrics/errors/segmentation.py b/pyannote/metrics/errors/segmentation.py index 1cfb624..34c9c22 100644 --- a/pyannote/metrics/errors/segmentation.py +++ b/pyannote/metrics/errors/segmentation.py @@ -25,17 +25,18 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr - +from typing import Union from pyannote.core import Annotation, Timeline -class SegmentationErrorAnalysis(object): +class SegmentationErrorAnalysis: def __init__(self): - super(SegmentationErrorAnalysis, self).__init__() + super().__init__() - def __call__(self, reference, hypothesis): + def __call__(self, reference: Union[Timeline, Annotation], + hypothesis: Union[Timeline, Annotation]) -> Annotation: if isinstance(reference, Annotation): reference = reference.get_timeline() diff --git a/pyannote/metrics/gamma.py b/pyannote/metrics/gamma.py new file mode 100644 index 0000000..00ec414 --- /dev/null +++ b/pyannote/metrics/gamma.py @@ -0,0 +1,211 @@ +import os +from concurrent.futures.thread import ThreadPoolExecutor +from typing import Union, Optional, Tuple + +import numpy as np +from pyannote.core import Annotation, Timeline +try: + from pygamma_agreement import GammaResults, PositionalSporadicDissimilarity, Continuum, \ + CombinedCategoricalDissimilarity, AbsoluteCategoricalDissimilarity, \ + AbstractDissimilarity + from pygamma_agreement.continuum import _compute_gamma_k_job +except ImportError as err: + raise ImportError("pygamma-agreement cannot be imported, " + "run `pip install pyannote.metrics[gamma]` " + "to fix this") from err + +from sortedcontainers import SortedSet + +from pyannote.metrics.base import BaseMetric +from 
pyannote.metrics.types import Details, MetricComponents +from pyannote.metrics.utils import UEMSupportMixin + +__all__ = [ + 'GammaDetectionError', + 'GammaIdentificationError', + 'GammaCategorizationError' +] + +GAMMA_DISORDER = "disorder" +GAMMA_CHANCE_DISORDER = "chance disorder" + + +class BaseGammaMetric(UEMSupportMixin, BaseMetric): + @classmethod + def metric_name(cls) -> str: + return "BaseGamma" + + dissim: AbstractDissimilarity + + def __init__(self, + collar: float = 0., + skip_overlap: bool = False, + **kwargs): + super().__init__(**kwargs) + self.collar = collar + self.skip_overlap = skip_overlap + + @classmethod + def metric_components(cls) -> MetricComponents: + return [GAMMA_DISORDER, GAMMA_CHANCE_DISORDER] + + def compute_metric(self, components: Details): + if components[GAMMA_CHANCE_DISORDER] == 0.0: + return 0. + else: + return 1 - (components[GAMMA_DISORDER] / components[GAMMA_CHANCE_DISORDER]) + + def compute_gamma(self, + reference: Union[Annotation, Timeline], + hypothesis: Union[Annotation, Timeline], + ) -> Tuple[float, float]: + raise NotImplementedError() + + def compute_components(self, + reference: Union[Annotation, Timeline], + hypothesis: Union[Annotation, Timeline], + uem: Optional[Timeline] = None, + collar: Optional[float] = None, + skip_overlap: Optional[bool] = None, + **kwargs) -> Details: + + if collar is None: + collar = self.collar + if skip_overlap is None: + skip_overlap = self.skip_overlap + + reference, hypothesis, uem = self.uemify( + reference, hypothesis, uem=uem, + collar=collar, skip_overlap=skip_overlap, + returns_uem=True) + + observed, expected = self.compute_gamma(reference, hypothesis) + return { + GAMMA_DISORDER: observed, + GAMMA_CHANCE_DISORDER: expected + } + + def report(self, display=False): + df = super().report(display=False) + + # mean of all column's totals instead of the sum + df.loc["TOTAL"] = df.loc["TOTAL"] / (len(df.index) - 1) + + if display: + print( + df.to_string( + index=True, + sparsify=False, + justify="right", + float_format=lambda f: "{0:.2f}".format(f), + ) + ) + + return df + + +class GammaDetectionError(BaseGammaMetric, UEMSupportMixin): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dissim = PositionalSporadicDissimilarity(delta_empty=1.0) + + @classmethod + def metric_name(cls) -> str: + return "GammaDet" + + def compute_gamma(self, + reference: Union[Annotation, Timeline], + hypothesis: Union[Annotation, Timeline], + ) -> Tuple[float, float]: + if isinstance(reference, Annotation): + reference = reference.get_timeline(copy=True) + + if isinstance(hypothesis, Annotation): + hypothesis = hypothesis.get_timeline(copy=True) + + continuum = Continuum() + continuum.add_timeline("reference", reference) + continuum.add_timeline("hypothesis", hypothesis) + gamma_results = continuum.compute_gamma(self.dissim, + precision_level="medium", + ground_truth_annotators=SortedSet(["reference"]), + soft=True) + return gamma_results.observed_disorder, gamma_results.expected_disorder + + +class GammaIdentificationError(BaseGammaMetric, UEMSupportMixin): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dissim = CombinedCategoricalDissimilarity( + pos_dissim=PositionalSporadicDissimilarity(delta_empty=1.0), + cat_dissim=AbsoluteCategoricalDissimilarity(delta_empty=1.0) + ) + + @classmethod + def metric_name(cls) -> str: + return "GammaId" + + def compute_gamma(self, + reference: Annotation, + hypothesis: Annotation, + ) -> Tuple[float, float]: + assert isinstance(reference, 
Annotation) + assert isinstance(hypothesis, Annotation) + continuum = Continuum() + continuum.add_annotation("reference", reference) + continuum.add_annotation("hypothesis", hypothesis) + gamma_results = continuum.compute_gamma(self.dissim, + precision_level="medium", + ground_truth_annotators=SortedSet(["reference"]), + soft=True) + return gamma_results.observed_disorder, gamma_results.expected_disorder + + +class GammaCategorizationError(BaseGammaMetric, UEMSupportMixin): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.dissim = CombinedCategoricalDissimilarity( + pos_dissim=PositionalSporadicDissimilarity(delta_empty=1.0), + cat_dissim=AbsoluteCategoricalDissimilarity(delta_empty=1.0) + ) + + @classmethod + def metric_name(cls) -> str: + return "GammaCat" + + def _compute_gamma_cat(self, gamma_result: GammaResults) -> Tuple[float, float]: + with ThreadPoolExecutor(max_workers=os.cpu_count()) as p: + + observed_disorder_job = p.submit(_compute_gamma_k_job, + *(gamma_result.dissimilarity, + gamma_result.best_alignment, None)) + + chance_disorders_jobs = [ + p.submit(_compute_gamma_k_job, + *(gamma_result.dissimilarity, alignment, None)) + for alignment in gamma_result.chance_alignments + ] + observed_disorder = observed_disorder_job.result() + if observed_disorder == 0: + return 0, 1 + expected_disorder = float(np.mean(np.array([job_res.result() for job_res in chance_disorders_jobs]))) + if expected_disorder == 0: + return 0, 0 + return observed_disorder, expected_disorder + + def compute_gamma(self, + reference: Annotation, + hypothesis: Annotation, + ) -> Tuple[float, float]: + continuum = Continuum() + continuum.add_annotation("reference", reference) + continuum.add_annotation("hypothesis", hypothesis) + gamma_results = continuum.compute_gamma(self.dissim, + precision_level="medium", + ground_truth_annotators=SortedSet(["reference"]), + soft=True) # TODO: find out if soft or not for this one + + return self._compute_gamma_cat(gamma_results) diff --git a/pyannote/metrics/identification.py b/pyannote/metrics/identification.py index de2d494..0ffddc5 100755 --- a/pyannote/metrics/identification.py +++ b/pyannote/metrics/identification.py @@ -25,6 +25,9 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr +from typing import Optional + +from pyannote.core import Annotation, Timeline from .base import BaseMetric from .base import Precision, PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED @@ -32,8 +35,10 @@ from .matcher import LabelMatcher, \ MATCH_TOTAL, MATCH_CORRECT, MATCH_CONFUSION, \ MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM +from .types import MetricComponents, Details from .utils import UEMSupportMixin +# TODO: can't we put these as class attributes? 
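Before the identification hunks that follow, a short usage sketch of `IdentificationErrorRate`, whose weighted error formula and typed constructor appear in this file; it is illustrative only (not part of the changeset), with invented segments and labels.

```python
# Illustrative sketch, not part of the diff: the weighted identification error
# rate typed below. confusion/miss/false_alarm/collar are its documented parameters.
from pyannote.core import Annotation, Segment
from pyannote.metrics.identification import IdentificationErrorRate

reference = Annotation(uri="file1")
reference[Segment(0.0, 5.0)] = "alice"
reference[Segment(5.0, 9.0)] = "bob"

hypothesis = Annotation(uri="file1")
hypothesis[Segment(0.0, 6.0)] = "alice"
hypothesis[Segment(6.0, 9.0)] = "carol"

ier = IdentificationErrorRate(confusion=1.0, miss=1.0, false_alarm=1.0, collar=0.25)
rate = ier(reference, hypothesis)                   # (confusion + false alarm + miss) / total
detail = ier(reference, hypothesis, detailed=True)  # Details dict with one entry per component
```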
IER_TOTAL = MATCH_TOTAL IER_CORRECT = MATCH_CORRECT IER_CONFUSION = MATCH_CONFUSION @@ -68,21 +73,26 @@ class IdentificationErrorRate(UEMSupportMixin, BaseMetric): """ @classmethod - def metric_name(cls): + def metric_name(cls) -> str: return IER_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [ IER_TOTAL, IER_CORRECT, IER_FALSE_ALARM, IER_MISS, IER_CONFUSION] - def __init__(self, confusion=1., miss=1., false_alarm=1., - collar=0., skip_overlap=False, **kwargs): + def __init__(self, + confusion: float = 1., + miss: float = 1., + false_alarm: float = 1., + collar: float = 0., + skip_overlap: bool = False, + **kwargs): - super(IdentificationErrorRate, self).__init__(**kwargs) + super().__init__(**kwargs) self.matcher_ = LabelMatcher() self.confusion = confusion self.miss = miss @@ -90,8 +100,13 @@ def __init__(self, confusion=1., miss=1., false_alarm=1., self.collar = collar self.skip_overlap = skip_overlap - def compute_components(self, reference, hypothesis, uem=None, - collar=None, skip_overlap=None, **kwargs): + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + collar: Optional[float] = None, + skip_overlap: Optional[float] = None, + **kwargs) -> Details: """ Parameters @@ -122,7 +137,6 @@ def compute_components(self, reference, hypothesis, uem=None, # loop on all segments for segment in common_timeline: - # segment duration duration = segment.duration @@ -142,12 +156,12 @@ def compute_components(self, reference, hypothesis, uem=None, return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: numerator = 1. * ( - self.confusion * detail[IER_CONFUSION] + - self.false_alarm * detail[IER_FALSE_ALARM] + - self.miss * detail[IER_MISS] + self.confusion * detail[IER_CONFUSION] + + self.false_alarm * detail[IER_FALSE_ALARM] + + self.miss * detail[IER_MISS] ) denominator = 1. * detail[IER_TOTAL] if denominator == 0.: @@ -172,14 +186,17 @@ class IdentificationPrecision(UEMSupportMixin, Precision): Defaults to False (i.e. keep overlap regions). """ - def __init__(self, collar=0., skip_overlap=False, **kwargs): - super(IdentificationPrecision, self).__init__(**kwargs) + def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs): + super().__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.matcher_ = LabelMatcher() - def compute_components(self, reference, hypothesis, uem=None, **kwargs): - + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: detail = self.init_components() R, H, common_timeline = self.uemify( @@ -189,7 +206,6 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): # loop on all segments for segment in common_timeline: - # segment duration duration = segment.duration @@ -221,14 +237,17 @@ class IdentificationRecall(UEMSupportMixin, Recall): Defaults to False (i.e. keep overlap regions). 
""" - def __init__(self, collar=0., skip_overlap=False, **kwargs): - super(IdentificationRecall, self).__init__(**kwargs) + def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs): + super().__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.matcher_ = LabelMatcher() - def compute_components(self, reference, hypothesis, uem=None, **kwargs): - + def compute_components(self, + reference: Annotation, + hypothesis: Annotation, + uem: Optional[Timeline] = None, + **kwargs) -> Details: detail = self.init_components() R, H, common_timeline = self.uemify( @@ -238,7 +257,6 @@ def compute_components(self, reference, hypothesis, uem=None, **kwargs): # loop on all segments for segment in common_timeline: - # segment duration duration = segment.duration diff --git a/pyannote/metrics/matcher.py b/pyannote/metrics/matcher.py index d61b957..13e9bb1 100644 --- a/pyannote/metrics/matcher.py +++ b/pyannote/metrics/matcher.py @@ -25,10 +25,15 @@ # AUTHORS # Hervé BREDIN - http://herve.niderb.fr +from typing import Dict, Tuple, Iterable, List, TYPE_CHECKING import numpy as np +from pyannote.core import Annotation from scipy.optimize import linear_sum_assignment +if TYPE_CHECKING: + from pyannote.core.utils.types import Label + MATCH_CORRECT = 'correct' MATCH_CONFUSION = 'confusion' MATCH_MISSED_DETECTION = 'missed detection' @@ -36,16 +41,16 @@ MATCH_TOTAL = 'total' -class LabelMatcher(object): +class LabelMatcher: """ - ID matcher base class. + ID matcher base class mixin. All ID matcher classes must inherit from this class and implement .match() -- ie return True if two IDs match and False otherwise. """ - def match(self, rlabel, hlabel): + def match(self, rlabel: 'Label', hlabel: 'Label') -> bool: """ Parameters ---------- @@ -63,7 +68,9 @@ def match(self, rlabel, hlabel): # Two IDs match if they are equal to each other return rlabel == hlabel - def __call__(self, rlabels, hlabels): + def __call__(self, rlabels: Iterable['Label'], hlabels: Iterable['Label']) \ + -> Tuple[Dict[str, int], + Dict[str, List['Label']]]: """ Parameters @@ -93,6 +100,10 @@ def __call__(self, rlabels, hlabels): MATCH_MISSED_DETECTION: [], MATCH_FALSE_ALARM: [] } + # this is to make sure rlabels and hlabels are lists + # as we will access them later by index + rlabels = list(rlabels) + hlabels = list(hlabels) NR = len(rlabels) NH = len(hlabels) @@ -100,12 +111,7 @@ def __call__(self, rlabels, hlabels): # corner case if N == 0: - return (counts, details) - - # this is to make sure rlabels and hlabels are lists - # as we will access them later by index - rlabels = list(rlabels) - hlabels = list(hlabels) + return counts, details # initialize match matrix # with True if labels match and False otherwise @@ -136,7 +142,7 @@ def __call__(self, rlabels, hlabels): counts[MATCH_CORRECT] += 1 details[MATCH_CORRECT].append((rlabels[r], hlabels[h])) - # refernece and hypothesis do not match + # reference and hypothesis do not match # ==> this is a confusion else: counts[MATCH_CONFUSION] += 1 @@ -145,12 +151,12 @@ def __call__(self, rlabels, hlabels): counts[MATCH_TOTAL] += NR # returns counts and details - return (counts, details) + return counts, details -class HungarianMapper(object): +class HungarianMapper: - def __call__(self, A, B): + def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']: mapping = {} cooccurrence = A * B @@ -163,9 +169,9 @@ def __call__(self, A, B): return mapping -class GreedyMapper(object): +class GreedyMapper: - def __call__(self, A, B): + def 
__call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']: mapping = {} cooccurrence = A * B diff --git a/pyannote/metrics/plot/binary_classification.py b/pyannote/metrics/plot/binary_classification.py index 35c3636..c69bee6 100644 --- a/pyannote/metrics/plot/binary_classification.py +++ b/pyannote/metrics/plot/binary_classification.py @@ -28,18 +28,28 @@ import warnings +from typing import Optional, Tuple + +import matplotlib import numpy as np +from numpy.typing import ArrayLike + from pyannote.metrics.binary_classification import det_curve from pyannote.metrics.binary_classification import precision_recall_curve -import matplotlib with warnings.catch_warnings(): warnings.simplefilter("ignore") matplotlib.use('Agg') import matplotlib.pyplot as plt -def plot_distributions(y_true, scores, save_to, xlim=None, nbins=100, ymax=3., dpi=150): +def plot_distributions(y_true: ArrayLike, + scores: ArrayLike, + save_to: str, + xlim: Optional[Tuple[float, float]] = None, + nbins: int = 100, + ymax: float = 3., + dpi: int = 150) -> bool: """Scores distributions This function will create (and overwrite) the following files: @@ -75,8 +85,11 @@ def plot_distributions(y_true, scores, save_to, xlim=None, nbins=100, ymax=3., d return True -def plot_det_curve(y_true, scores, save_to, - distances=False, dpi=150): +def plot_det_curve(y_true: ArrayLike, + scores: ArrayLike, + save_to: str, + distances: bool = False, + dpi: int = 150) -> float: """DET curve This function will create (and overwrite) the following files: @@ -129,8 +142,11 @@ def plot_det_curve(y_true, scores, save_to, return eer -def plot_precision_recall_curve(y_true, scores, save_to, - distances=False, dpi=150): +def plot_precision_recall_curve(y_true: ArrayLike, + scores: ArrayLike, + save_to: str, + distances: bool = False, + dpi: int = 150) -> float: """Precision/recall curve This function will create (and overwrite) the following files: diff --git a/pyannote/metrics/segmentation.py b/pyannote/metrics/segmentation.py index 6c244d4..200d6e2 100755 --- a/pyannote/metrics/segmentation.py +++ b/pyannote/metrics/segmentation.py @@ -28,14 +28,17 @@ # Camille Guinaudeau - https://sites.google.com/site/cguinaudeau/ # Mamadou Doumbia # Diego Fustes diego.fustes at toptal.com +from typing import Tuple, Union, Optional import numpy as np from pyannote.core import Segment, Timeline, Annotation from pyannote.core.utils.generators import pairwise from .base import BaseMetric, f_measure +from .types import MetricComponents, Details from .utils import UEMSupportMixin +#  TODO: can't we put these as class attributes? 
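For context on the plotting helpers annotated above, a hedged usage sketch; the random arrays and the `/tmp/...` output prefixes are invented, and per the docstrings each helper writes image files derived from the `save_to` prefix.

```python
# Illustrative sketch, not part of the diff: calling the typed plotting helpers.
# The label/score arrays are synthetic and the output paths are placeholders.
import numpy as np
from pyannote.metrics.plot.binary_classification import plot_det_curve, plot_distributions

rng = np.random.default_rng(0)
y_true = rng.random(1000) > 0.5                      # binary ground truth
scores = y_true + rng.normal(scale=0.8, size=1000)   # noisy detection scores

plot_distributions(y_true, scores, "/tmp/scores", nbins=50)
eer = plot_det_curve(y_true, scores, "/tmp/det", distances=False, dpi=150)
print(f"EER = {eer:.3f}")
```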
PURITY_NAME = 'segmentation purity' COVERAGE_NAME = 'segmentation coverage' PURITY_COVERAGE_NAME = 'segmentation F[purity|coverage]' @@ -65,11 +68,13 @@ class SegmentationCoverage(BaseMetric): """ - def __init__(self, tolerance=0.500, **kwargs): - super(SegmentationCoverage, self).__init__(**kwargs) + def __init__(self, tolerance: float = 0.500, **kwargs): + super().__init__(**kwargs) self.tolerance = tolerance - def _partition(self, timeline, coverage): + def _partition(self, + timeline: Timeline, + coverage: Timeline) -> Annotation: # boundaries (as set of timestamps) boundaries = set([]) @@ -85,13 +90,15 @@ def _partition(self, timeline, coverage): return partition.crop(coverage, mode='intersection').relabel_tracks() - def _preprocess(self, reference, hypothesis): + def _preprocess(self, reference: Annotation, + hypothesis: Union[Annotation, Timeline]) \ + -> Tuple[Annotation, Annotation]: if not isinstance(reference, Annotation): raise TypeError('reference must be an instance of `Annotation`') if isinstance(hypothesis, Annotation): - hypothesis = hypothesis.get_timeline() + hypothesis: Timeline = hypothesis.get_timeline() # reference where short intra-label gaps are removed filled = Timeline() @@ -112,7 +119,7 @@ def _preprocess(self, reference, hypothesis): return reference_partition, hypothesis_partition - def _process(self, reference, hypothesis): + def _process(self, reference: Annotation, hypothesis: Annotation) -> Details: detail = self.init_components() @@ -128,14 +135,15 @@ def metric_name(cls): return COVERAGE_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [PTY_CVG_TOTAL, PTY_CVG_INTER] - def compute_components(self, reference, hypothesis, **kwargs): + def compute_components(self, reference: Annotation, + hypothesis: Union[Annotation, Timeline], **kwargs): reference, hypothesis = self._preprocess(reference, hypothesis) return self._process(reference, hypothesis) - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: return detail[PTY_CVG_INTER] / detail[PTY_CVG_TOTAL] @@ -151,10 +159,13 @@ class SegmentationPurity(SegmentationCoverage): """ @classmethod - def metric_name(cls): + def metric_name(cls) -> str: return PURITY_NAME - def compute_components(self, reference, hypothesis, **kwargs): + # TODO : Use type from parent class + def compute_components(self, reference: Annotation, + hypothesis: Union[Annotation, Timeline], + **kwargs) -> Details: reference, hypothesis = self._preprocess(reference, hypothesis) return self._process(hypothesis, reference) @@ -186,7 +197,8 @@ def __init__(self, tolerance=0.500, beta=1, **kwargs): super(SegmentationPurityCoverageFMeasure, self).__init__(tolerance=tolerance, **kwargs) self.beta = beta - def _process(self, reference, hypothesis): + def _process(self, reference: Annotation, + hypothesis: Union[Annotation, Timeline]) -> Details: reference, hypothesis = self._preprocess(reference, hypothesis) detail = self.init_components() @@ -202,32 +214,35 @@ def _process(self, reference, hypothesis): return detail - def compute_components(self, reference, hypothesis, **kwargs): + def compute_components(self, reference: Annotation, + hypothesis: Union[Annotation, Timeline], + **kwargs) -> Details: return self._process(reference, hypothesis) - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: _, _, value = self.compute_metrics(detail=detail) return value - def compute_metrics(self, detail=None): + def compute_metrics(self, 
detail: Optional[Details] = None) \ + -> Tuple[float, float, float]: detail = self.accumulated_ if detail is None else detail purity = \ 1. if detail[PTY_TOTAL] == 0. \ - else detail[PTY_INTER] / detail[PTY_TOTAL] + else detail[PTY_INTER] / detail[PTY_TOTAL] coverage = \ 1. if detail[CVG_TOTAL] == 0. \ - else detail[CVG_INTER] / detail[CVG_TOTAL] + else detail[CVG_INTER] / detail[CVG_TOTAL] return purity, coverage, f_measure(purity, coverage, beta=self.beta) @classmethod - def metric_name(cls): + def metric_name(cls) -> str: return PURITY_COVERAGE_NAME @classmethod - def metric_components(cls): + def metric_components(cls) -> MetricComponents: return [PTY_TOTAL, PTY_INTER, CVG_TOTAL, CVG_INTER] @@ -268,10 +283,13 @@ def metric_components(cls): def __init__(self, tolerance=0., **kwargs): - super(SegmentationPrecision, self).__init__(**kwargs) + super().__init__(**kwargs) self.tolerance = tolerance - def compute_components(self, reference, hypothesis, **kwargs): + def compute_components(self, + reference: Union[Annotation, Timeline], + hypothesis: Union[Annotation, Timeline], + **kwargs) -> Details: # extract timeline if needed if isinstance(reference, Annotation): @@ -282,7 +300,7 @@ def compute_components(self, reference, hypothesis, **kwargs): detail = self.init_components() # number of matches so far... - nMatches = 0. # make sure it is a float (for later ratio) + n_matches = 0. # make sure it is a float (for later ratio) # number of boundaries in reference and hypothesis N = len(reference) - 1 @@ -297,13 +315,13 @@ def compute_components(self, reference, hypothesis, **kwargs): return detail # reference and hypothesis boundaries - refBoundaries = [segment.end for segment in reference][:-1] - hypBoundaries = [segment.end for segment in hypothesis][:-1] + ref_boundaries = [segment.end for segment in reference][:-1] + hyp_boundaries = [segment.end for segment in hypothesis][:-1] # temporal delta between all pairs of boundaries delta = np.zeros((N, M)) - for r, refBoundary in enumerate(refBoundaries): - for h, hypBoundary in enumerate(hypBoundaries): + for r, refBoundary in enumerate(ref_boundaries): + for h, hypBoundary in enumerate(hyp_boundaries): delta[r, h] = abs(refBoundary - hypBoundary) # make sure boundaries too far apart from each other cannot be matched @@ -317,7 +335,7 @@ def compute_components(self, reference, hypothesis, **kwargs): # while there are still boundaries to match while h < np.inf: # increment match count - nMatches += 1 + n_matches += 1 # find boundaries to match k = np.argmin(delta) @@ -331,10 +349,10 @@ def compute_components(self, reference, hypothesis, **kwargs): # update minimum value in delta h = np.amin(delta) - detail[PR_MATCHES] = nMatches + detail[PR_MATCHES] = n_matches return detail - def compute_metric(self, detail): + def compute_metric(self, detail: Details) -> float: numerator = detail[PR_MATCHES] denominator = detail[PR_BOUNDARIES] @@ -379,6 +397,8 @@ class SegmentationRecall(SegmentationPrecision): def metric_name(cls): return RECALL_NAME - def compute_components(self, reference, hypothesis, **kwargs): + def compute_components(self, reference: Union[Annotation, Timeline], + hypothesis: Union[Annotation, Timeline], + **kwargs) -> Details: return super(SegmentationRecall, self).compute_components( hypothesis, reference) diff --git a/pyannote/metrics/spotting.py b/pyannote/metrics/spotting.py index 318d892..01983b1 100644 --- a/pyannote/metrics/spotting.py +++ b/pyannote/metrics/spotting.py @@ -27,11 +27,22 @@ # Hervé BREDIN - http://herve.niderb.fr 
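Before moving on to `spotting.py`, a usage sketch for the segmentation metrics typed above; illustrative only (not part of the changeset), with invented boundaries, where `tolerance` is the boundary-matching tolerance in seconds.

```python
# Illustrative sketch, not part of the diff: segmentation precision/recall and
# the purity/coverage F-measure whose compute_metrics() helper is typed above.
from pyannote.core import Annotation, Segment
from pyannote.metrics.segmentation import (SegmentationPrecision,
                                           SegmentationPurityCoverageFMeasure,
                                           SegmentationRecall)

reference = Annotation(uri="file1")
reference[Segment(0.0, 4.0)] = "A"
reference[Segment(4.0, 9.0)] = "B"

hypothesis = Annotation(uri="file1")
hypothesis[Segment(0.0, 4.5)] = "a"
hypothesis[Segment(4.5, 9.0)] = "b"

precision = SegmentationPrecision(tolerance=0.5)   # boundary matching tolerance, in seconds
recall = SegmentationRecall(tolerance=0.5)
fmeasure = SegmentationPurityCoverageFMeasure(tolerance=0.5, beta=1.0)

print(precision(reference, hypothesis), recall(reference, hypothesis))
fmeasure(reference, hypothesis)                         # accumulates components
purity, coverage, f_value = fmeasure.compute_metrics()  # accumulated purity, coverage, F-measure
```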
diff --git a/pyannote/metrics/spotting.py b/pyannote/metrics/spotting.py
index 318d892..01983b1 100644
--- a/pyannote/metrics/spotting.py
+++ b/pyannote/metrics/spotting.py
@@ -27,11 +27,22 @@
 # Hervé BREDIN - http://herve.niderb.fr
 
 import sys
+from typing import Union, Iterable, Optional, Tuple, List
+
 import numpy as np
+from numpy.typing import ArrayLike
+from pyannote.core import Segment, Annotation, SlidingWindowFeature, Timeline
+
 from .base import BaseMetric
 from .binary_classification import det_curve
-from pyannote.core import Segment, Annotation
-from pyannote.core import SlidingWindowFeature
+from .types import MetricComponents, Details
+
+SPOTTING_TARGET = "target"
+SPOTTING_SPK_LATENCY = 'speaker_latency'
+SPOTTING_SPK_SCORE = 'spk_score'
+SPOTTING_ABS_LATENCY = 'absolute_latency'
+SPOTTING_ABS_SCORE = "abs_score"
+SPOTTING_SCORE = "score"
 
 
 class LowLatencySpeakerSpotting(BaseMetric):
@@ -63,15 +74,20 @@ class LowLatencySpeakerSpotting(BaseMetric):
     """
 
     @classmethod
-    def metric_name(cls):
+    def metric_name(cls) -> str:
         return "Low-latency speaker spotting"
 
-    @classmethod
-    def metric_components(cls):
-        return {'target': 0.}
+    def metric_components(self) -> MetricComponents:
+        if self.latencies is None:
+            return [SPOTTING_TARGET, SPOTTING_ABS_LATENCY, SPOTTING_SPK_SCORE, SPOTTING_SCORE]
+        else:
+            return [SPOTTING_TARGET, SPOTTING_SPK_LATENCY, SPOTTING_SPK_SCORE,
+                    SPOTTING_ABS_LATENCY, SPOTTING_ABS_SCORE]
 
-    def __init__(self, thresholds=None, latencies=None):
-        super(LowLatencySpeakerSpotting, self).__init__()
+    def __init__(self,
+                 thresholds: Optional[ArrayLike] = None,
+                 latencies: Optional[ArrayLike] = None):
+        super().__init__()
         if thresholds is None and latencies is None:
             latencies = [1, 5, 10, 30, 60]
@@ -88,10 +104,11 @@ def __init__(self, thresholds=None, latencies=None):
 
         self.latencies = latencies
 
-    def compute_metric(self, detail):
+    def compute_metric(self, detail: MetricComponents):
         return None
 
-    def _fixed_latency(self, reference, timestamps, scores):
+    def _fixed_latency(self, reference: Timeline,
+                       timestamps: List[float], scores: List[float]) -> Details:
 
         if not reference:
             target_trial = False
@@ -148,7 +165,9 @@ def _fixed_latency(self, reference, timestamps, scores):
             'abs_score': abs_score,
         }
 
-    def _variable_latency(self, reference, timestamps, scores, **kwargs):
+    def _variable_latency(self, reference: Union[Timeline, Annotation],
+                          timestamps: List[float], scores: List[float],
+                          **kwargs) -> Details:
 
         # pre-compute latencies
         speaker_latency = np.NAN * np.ones((len(timestamps), 1))
@@ -201,7 +220,10 @@ def _variable_latency(self, reference, timestamps, scores, **kwargs):
             'score': np.max(scores)
         }
 
-    def compute_components(self, reference, hypothesis, **kwargs):
+    def compute_components(self, reference: Union[Timeline, Annotation],
+                           hypothesis: Union[SlidingWindowFeature,
+                                             Iterable[Tuple[float, float]]],
+                           **kwargs) -> Details:
         """
 
         Parameters
@@ -232,8 +254,12 @@ def speaker_latency(self):
                      if trial['target']]
         return np.nanmean(latencies, axis=0)
 
-    def det_curve(self, cost_miss=100, cost_fa=1, prior_target=0.01,
-                  return_latency=False):
+    # TODO : figure out return type
+    def det_curve(self,
+                  cost_miss: float = 100,
+                  cost_fa: float = 1,
+                  prior_target: float = 0.01,
+                  return_latency: bool = False):
         """DET curve
 
         Parameters
@@ -272,7 +298,7 @@ def det_curve(self, cost_miss=100, cost_fa=1, prior_target=0.01,
             fpr, fnr, thresholds, eer = det_curve(y_true, scores, distances=False)
             fpr, fnr, thresholds = fpr[::-1], fnr[::-1], thresholds[::-1]
             cdet = cost_miss * fnr * prior_target + \
-                cost_fa * fpr * (1. - prior_target)
+                   cost_fa * fpr * (1. - prior_target)
 
             if return_latency:
                 # needed to align the thresholds used in the DET curve
@@ -284,7 +310,7 @@ def det_curve(self, cost_miss=100, cost_fa=1, prior_target=0.01,
                 fnr = np.take(fnr, indices, mode='clip')
                 cdet = np.take(cdet, indices, mode='clip')
                 return thresholds, fpr, fnr, eer, cdet, \
-                    self.speaker_latency, self.absolute_latency
+                       self.speaker_latency, self.absolute_latency
 
             else:
                 return thresholds, fpr, fnr, eer, cdet
@@ -306,7 +332,7 @@ def det_curve(self, cost_miss=100, cost_fa=1, prior_target=0.01,
                                                   distances=False)
                     fpr, fnr, theta = fpr[::-1], fnr[::-1], theta[::-1]
                     cdet = cost_miss * fnr * prior_target + \
-                        cost_fa * fpr * (1. - prior_target)
+                           cost_fa * fpr * (1. - prior_target)
                     result[key][latency] = theta, fpr, fnr, eer, cdet
 
             return result
diff --git a/pyannote/metrics/types.py b/pyannote/metrics/types.py
new file mode 100644
index 0000000..a51dc4d
--- /dev/null
+++ b/pyannote/metrics/types.py
@@ -0,0 +1,8 @@
+from typing import Dict, List
+
+from typing_extensions import Literal
+
+MetricComponent = str
+CalibrationMethod = Literal["isotonic", "sigmoid"]
+MetricComponents = List[MetricComponent]
+Details = Dict[MetricComponent, float]
\ No newline at end of file
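
Note (editorial, not part of the patch): the new `pyannote/metrics/types.py` module only introduces aliases, but it spells out the contract that `BaseMetric` subclasses follow. A sketch of a hypothetical downstream metric written against these aliases; the class name and component names below are invented for illustration:

```python
from pyannote.core import Annotation
from pyannote.metrics.base import BaseMetric
from pyannote.metrics.types import Details, MetricComponents

TOTAL = 'total'
MATCHED = 'matched'


class ExactSegmentMatch(BaseMetric):
    """Hypothetical metric: fraction of reference segments found verbatim in hypothesis."""

    @classmethod
    def metric_name(cls) -> str:
        return 'exact segment match'

    @classmethod
    def metric_components(cls) -> MetricComponents:
        return [TOTAL, MATCHED]

    def compute_components(self, reference: Annotation, hypothesis: Annotation,
                           **kwargs) -> Details:
        # component values are plain floats, keyed by component name
        ref_segments = set(reference.itersegments())
        hyp_segments = set(hypothesis.itersegments())
        return {TOTAL: float(len(ref_segments)),
                MATCHED: float(len(ref_segments & hyp_segments))}

    def compute_metric(self, components: Details) -> float:
        return 1. if components[TOTAL] == 0. else components[MATCHED] / components[TOTAL]
```

Since `MetricComponents` is just `List[str]` and `Details` just `Dict[str, float]`, existing subclasses keep working without modification.
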
diff --git a/pyannote/metrics/utils.py b/pyannote/metrics/utils.py
index fe4b751..4df3789 100644
--- a/pyannote/metrics/utils.py
+++ b/pyannote/metrics/utils.py
@@ -27,13 +27,19 @@
 # Hervé BREDIN - http://herve.niderb.fr
 
 import warnings
-from pyannote.core import Timeline, Segment
+from typing import Optional, Tuple, Union
+
+from pyannote.core import Timeline, Segment, Annotation
 
 
 class UEMSupportMixin:
     """Provides 'uemify' method with optional (à la NIST) collar"""
 
-    def extrude(self, uem, reference, collar=0.0, skip_overlap=False):
+    def extrude(self,
+                uem: Timeline,
+                reference: Annotation,
+                collar: float = 0.0,
+                skip_overlap: bool = False) -> Timeline:
         """Extrude reference boundary collars from uem
 
         reference  |----|     |--------------|       |-------------|
@@ -68,7 +74,6 @@ def extrude(self, uem, reference, collar=0.0, skip_overlap=False):
         if collar > 0.:
             # iterate over all segments in reference
             for segment in reference.itersegments():
-
                 # add collar centered on start time
                 t = segment.start
                 collars.append(Segment(t - .5 * collar, t + .5 * collar))
@@ -90,7 +95,8 @@ def extrude(self, uem, reference, collar=0.0, skip_overlap=False):
 
         return Timeline(segments=segments).support().gaps(support=uem)
 
-    def common_timeline(self, reference, hypothesis):
+    def common_timeline(self, reference: Annotation, hypothesis: Annotation) \
+            -> Timeline:
         """Return timeline common to both reference and hypothesis
 
         reference |--------|    |------------|      |---------|    |----|
@@ -110,7 +116,7 @@ def common_timeline(self, reference, hypothesis):
         timeline.update(hypothesis.get_timeline(copy=False))
         return timeline.segmentation()
 
-    def project(self, annotation, timeline):
+    def project(self, annotation: Annotation, timeline: Timeline) -> Annotation:
         """Project annotation onto timeline segments
 
         reference |__A__|     |__B__|
@@ -138,8 +144,19 @@ def project(self, annotation, timeline):
                 projection[segment, track] = annotation[segment_, track_]
         return projection
 
-    def uemify(self, reference, hypothesis, uem=None, collar=0.,
-               skip_overlap=False, returns_uem=False, returns_timeline=False):
+    def uemify(self,
+               reference: Annotation,
+               hypothesis: Annotation,
+               uem: Optional[Timeline] = None,
+               collar: float = 0.,
+               skip_overlap: bool = False,
+               returns_uem: bool = False,
+               returns_timeline: bool = False) \
+            -> Union[
+                Tuple[Annotation, Annotation],
+                Tuple[Annotation, Annotation, Timeline],
+                Tuple[Annotation, Annotation, Timeline, Timeline],
+            ]:
         """Crop 'reference' and 'hypothesis' to 'uem' support
 
         Parameters
@@ -200,9 +217,9 @@ def uemify(self, reference, hypothesis, uem=None, collar=0.,
         result = (reference, hypothesis)
         if returns_uem:
-            result += (uem, )
+            result += (uem,)
         if returns_timeline:
-            result += (timeline, )
+            result += (timeline,)
 
         return result
diff --git a/setup.py b/setup.py
index 08eb201..11787ff 100755
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,7 @@
     # package
     namespace_packages=["pyannote"],
     packages=find_packages(),
-    entry_points={"console_scripts": ["pyannote-metrics=pyannote.metrics.cli:main",],},
+    entry_points={"console_scripts": ["pyannote-metrics=pyannote.metrics.cli:main", ], },
     install_requires=[
         "pyannote.core >= 4.1",
         "pyannote.database >= 4.0.1",
@@ -50,6 +50,7 @@
         "tabulate >= 0.7.7",
         "matplotlib >= 2.0.0",
         "sympy >= 1.1",
+        "numpy"
     ],
     # versioneer
     version=versioneer.get_version(),
@@ -72,4 +73,17 @@
         "Programming Language :: Python :: 3.8",
         "Topic :: Scientific/Engineering",
     ],
+    extras_require={
+        "gamma": {
+            "pygamma-agreement>=0.5.4"
+        },
+        "tests": {
+            'pytest'
+        },
+        "docs": [
+            "Sphinx==2.2.2",
+            "ipython==7.10.1",
+            "sphinx_rtd_theme==0.4.3"
+        ]
+    }
 )
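
Note (editorial, not part of the patch): with the `extras_require` block above, the optional dependency groups declared here can be pulled in with pip's extras syntax, e.g. `pip install .[gamma]` for the pygamma-agreement integration, or `pip install ".[tests,docs]"` to install the test and documentation toolchains in one go.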