Skip to content

Conversation

@foster999
Copy link

@foster999 foster999 commented May 23, 2025

Implements the threshold averaging method for ROC curve averaging (see #1702)

Includes simple test for averaging logic

Still todo:

  • Handle caching, so that average and none-average plots can be generated from the same report. It currently ignores new arguments and plots cached values. Should we cache averages separately?
  • Include constituent roc curves on average plot, or present a measure of variance/confidence
  • Handle plot kwargs for average ROC line (split_index is None)
  • Implement for PR curve, or add error to say averaging is undefined? Already added the parameter to PR, as _MetricsAccessor uses the same interface to call both PR and ROC displays.
  • Update docstrings with method description and reference

@foster999 foster999 force-pushed the 1702-average-roc branch 3 times, most recently from eb25f59 to 9c669e5 Compare May 23, 2025 11:30
@foster999
Copy link
Author

@glemaitre do you have any thoughts on the questions I've included above?

Regarding showing how much variation there is, I had something like this in mind:
example_average_roc

I generally have a large number of ROC curves in the average, so wouldn't want to show a legend for each one

@thomass-dev
Copy link
Collaborator

thomass-dev commented May 26, 2025

[automated comment] Please update your PR with main, so that the pytest workflow status will be reported.

@foster999
Copy link
Author

[automated comment] Please update your PR with main, so that the pytest workflow status will be reported.

Updated to use new format for passing data to displays. I think approval might be needed to get the workflows running?

@foster999 foster999 marked this pull request as ready for review May 28, 2025 12:38
@github-actions
Copy link
Contributor

github-actions bot commented May 28, 2025

Coverage

Coverage Report for skore/
FileStmtsMissCoverMissing
venv/lib/python3.12/site-packages/skore
   __init__.py230100% 
   _config.py280100% 
   exceptions.py440%4, 15, 19, 23
venv/lib/python3.12/site-packages/skore/project
   __init__.py20100% 
   metadata.py670100% 
   project.py430100% 
   reports.py110100% 
   widget.py138596%375–377, 447–448
venv/lib/python3.12/site-packages/skore/sklearn
   __init__.py60100% 
   _base.py1691491%45, 58, 126, 129, 182, 185–186, 188–191, 224, 227–228
   find_ml_task.py610100% 
   types.py220100% 
venv/lib/python3.12/site-packages/skore/sklearn/_comparison
   __init__.py50100% 
   metrics_accessor.py203398%170, 334, 1288
   report.py950100% 
   utils.py550100% 
venv/lib/python3.12/site-packages/skore/sklearn/_cross_validation
   __init__.py50100% 
   metrics_accessor.py207199%327
   report.py1180100% 
venv/lib/python3.12/site-packages/skore/sklearn/_estimator
   __init__.py70100% 
   feature_importance_accessor.py143298%216–217
   metrics_accessor.py371997%158, 187, 189, 196, 287, 356, 360, 375, 410
   report.py1550100% 
venv/lib/python3.12/site-packages/skore/sklearn/_plot
   __init__.py20100% 
   base.py50100% 
   style.py280100% 
   utils.py136596%51, 75–77, 81
venv/lib/python3.12/site-packages/skore/sklearn/_plot/metrics
   __init__.py50100% 
   confusion_matrix.py69494%90, 98, 120, 228
   precision_recall_curve.py230199%716
   prediction_error.py1600100% 
   roc_curve.py2953787%381, 431–433, 435–440, 442, 446, 451–452, 454, 460–462, 464–465, 470, 475, 477, 483–484, 490, 492–493, 495–496, 498, 607, 708, 874, 914, 1052, 1083
venv/lib/python3.12/site-packages/skore/sklearn/train_test_split
   __init__.py00100% 
   train_test_split.py490100% 
venv/lib/python3.12/site-packages/skore/sklearn/train_test_split/warning
   __init__.py80100% 
   high_class_imbalance_too_few_examples_warning.py17194%80
   high_class_imbalance_warning.py180100% 
   random_state_unset_warning.py100100% 
   shuffle_true_warning.py10190%46
   stratify_is_set_warning.py100100% 
   time_based_column_warning.py21195%73
   train_test_split_warning.py40100% 
venv/lib/python3.12/site-packages/skore/utils
   __init__.py6266%8, 13
   _accessor.py52296%67, 108
   _environment.py270100% 
   _fixes.py80100% 
   _index.py50100% 
   _logger.py22481%15–17, 19
   _measure_time.py100100% 
   _parallel.py38392%23, 33, 124
   _patch.py13561%21, 23–24, 35, 37
   _progress_bar.py450100% 
   _show_versions.py33293%65–66
   _testing.py370100% 
TOTAL331110696% 

Tests Skipped Failures Errors Time
816 5 💤 0 ❌ 0 🔥 1m 1s ⏱️

@github-actions
Copy link
Contributor

Documentation preview @ 06b5895

@auguste-probabl
Copy link
Contributor

It looks like you're still working on this PR; if so, can you set it to draft?

@foster999 foster999 marked this pull request as draft May 30, 2025 08:47
@foster999
Copy link
Author

It looks like you're still working on this PR; if so, can you set it to draft?

Thanks for taking another look, just changed to draft now

I could do with advice on a couple of points please:

display = display_class._compute_data_for_display(
y_true=y_true,
y_pred=y_pred,
average=average,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mypy is still complaining about this line, even after casting:
Unexpected keyword argument "average" for "_compute_data_for_display" of "PrecisionRecallCurveDisplay

Without casting it shows the same error twice. Any suggestions on why?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, the cast looks correct to me... maybe reveal_type could help?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Side-note: This is a sign that _get_display should stop existing, although that should be the subject of another PR.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, the cast looks correct to me... maybe reveal_type could help?

Thanks, though it looks how I would expect after casting:
Revealed type is "type[skore.sklearn._plot.metrics.roc_curve.RocCurveDisplay]"

average_roc_curve = self.roc_curve.query(query)
average_roc_auc = self.roc_auc.query(query)["roc_auc"].item()

line_kwargs_validated = _validate_style_kwargs({}, {})
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how best to take style kwargs for the average line. Other lines slice using split_idx, but this is None for the average line

Copy link
Contributor

@auguste-probabl auguste-probabl May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it might be time to make roc_curve_kwargs more specific in this case. Something like

Kwargs = dict[str, Any]
oneOrMore[T] = Union[T, list[T]]

class RocCurveKwargs(TypedDict):
	splits: oneOrMore[Kwargs]
	average: Optional[Kwargs]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants