1212from skore .externals ._pandas_accessors import DirNamesMixin
1313from skore .sklearn ._base import _BaseReport
1414from skore .sklearn ._estimator .report import EstimatorReport
15+ from skore .sklearn .utils import _SCORE_OR_LOSS_INFO
1516from skore .utils ._progress_bar import progress_decorator
1617
1718if TYPE_CHECKING :
@@ -351,18 +352,22 @@ def get_predictions(
351352 for report in self .estimator_reports_
352353 ]
353354
354- def plot_perf_against_time (
355+ def pairwise_plot (
355356 self ,
356- perf_metric : str ,
357+ perf_metric_x : str ,
358+ perf_metric_y : str ,
357359 data_source : Literal ["test" , "train" , "X_y" ] = "test" ,
358- time_metric : Literal [ "fit" , "predict" ] = "predict" ,
360+ pos_label : Optional [ Any ] = None ,
359361 ):
360- """
361- Plot a given performance metric against a time metric.
362+ """Plot a given performance metric against another.
362363
363364 Parameters
364365 ----------
365- perf_metric : str
366+ perf_metric_x : str
367+ The performance metric to plot on the abscissa axis.
368+
369+ perf_metric_y : str
370+ The performance metrics to plot on the ordinates axis.
366371
367372 data_source : {"test", "train", "X_y"}, default="test"
368373 The data source to use.
@@ -371,11 +376,12 @@ def plot_perf_against_time(
371376 - "train" : use the train set provided when creating the report.
372377 - "X_y" : use the provided `X` and `y` to compute the metric.
373378
374- perf_metric : str
375-
376- time_metric: {"fit", "predict"}, default = "predict"
377- The time metric to use in the plot.
378-
379+ pos_label : int, float, bool or str, default=None
380+ The positive class when it comes to binary classification. When
381+ `response_method="predict_proba"`, it will select the column corresponding
382+ to the positive class. When `response_method="decision_function"`, it will
383+ negate the decision function if `pos_label` is different from
384+ `estimator.classes_[1]`.
379385
380386 Returns
381387 -------
@@ -403,23 +409,21 @@ def plot_perf_against_time(
403409 # - turn into display
404410 # - change name to sth like `pairwise_plot`
405411
406- if time_metric == "fit" :
407- x_label = "Fit time"
408- elif time_metric == "predict" :
409- x_label = "Predict time"
412+ x_label = _SCORE_OR_LOSS_INFO [perf_metric_x ].get ("name" , perf_metric_x )
413+ y_label = _SCORE_OR_LOSS_INFO [perf_metric_y ].get ("name" , perf_metric_y )
410414
411- scatter_data = self .metrics .report_metrics ().T .reset_index ()
415+ scatter_data = self .metrics .report_metrics (pos_label = pos_label ).T .reset_index ()
412416 scatter_data .plot (
413417 kind = "scatter" ,
414418 x = x_label ,
415- y = "Brier score" ,
419+ y = y_label ,
416420 title = "Performance vs Time (s)" ,
417421 )
418422
419423 # Add labels to the points with a small offset
420424 text = scatter_data ["Estimator" ]
421425 x = scatter_data [x_label ]
422- y = scatter_data ["Brier score" ]
426+ y = scatter_data [y_label ]
423427 for label , x_coord , y_coord in zip (text , x , y ):
424428 plt .annotate (
425429 label ,
0 commit comments