feat(CrossValidationReport): Add threshold averaging for roc plot #1750
Conversation
Force-pushed from eb25f59 to 9c669e5.
@glemaitre do you have any thoughts on the questions I've included above? Regarding showing how much variation there is, I had something like this in mind. I generally have a large number of ROC curves in the average, so I wouldn't want to show a legend entry for each one.
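One way to show the spread without a per-curve legend entry is to interpolate each split's curve onto a shared FPR grid and draw the mean curve with a shaded ±1 std band, so the legend holds a single entry. This is a minimal vertical-averaging sketch, not skore's implementation (the PR itself uses threshold averaging); `per_split_fpr`/`per_split_tpr` are assumed to be the per-split outputs of `sklearn.metrics.roc_curve`:

```python
import numpy as np


def mean_roc_band(per_split_fpr, per_split_tpr, n_points=101):
    """Interpolate each split's ROC curve onto a common FPR grid and
    return (grid, mean_tpr, std_tpr), suitable for plotting one mean
    line plus a shaded +/-1 std band -- one legend entry per report,
    not one per split."""
    grid = np.linspace(0.0, 1.0, n_points)
    interp = np.stack(
        [np.interp(grid, fpr, tpr) for fpr, tpr in zip(per_split_fpr, per_split_tpr)]
    )
    return grid, interp.mean(axis=0), interp.std(axis=0)
```

With matplotlib, the band itself would then be `ax.fill_between(grid, mean - std, mean + std, alpha=0.2)`.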
[automated comment] Please update your PR with main, so that the
Force-pushed from b5a971a to 4e4af76.
Updated to use the new format for passing data to displays. I think approval might be needed to get the workflows running?
Resolved review thread on skore/src/skore/sklearn/_plot/metrics/precision_recall_curve.py (outdated).
Coverage Report
| File | Stmts | Miss | Cover | Missing |
|---|---|---|---|---|
| venv/lib/python3.12/site-packages/skore | | | | |
| __init__.py | 23 | 0 | 100% | |
| _config.py | 28 | 0 | 100% | |
| exceptions.py | 4 | 4 | 0% | 4, 15, 19, 23 |
| venv/lib/python3.12/site-packages/skore/project | | | | |
| __init__.py | 2 | 0 | 100% | |
| metadata.py | 67 | 0 | 100% | |
| project.py | 43 | 0 | 100% | |
| reports.py | 11 | 0 | 100% | |
| widget.py | 138 | 5 | 96% | 375–377, 447–448 |
| venv/lib/python3.12/site-packages/skore/sklearn | | | | |
| __init__.py | 6 | 0 | 100% | |
| _base.py | 169 | 14 | 91% | 45, 58, 126, 129, 182, 185–186, 188–191, 224, 227–228 |
| find_ml_task.py | 61 | 0 | 100% | |
| types.py | 22 | 0 | 100% | |
| venv/lib/python3.12/site-packages/skore/sklearn/_comparison | | | | |
| __init__.py | 5 | 0 | 100% | |
| metrics_accessor.py | 203 | 3 | 98% | 170, 334, 1288 |
| report.py | 95 | 0 | 100% | |
| utils.py | 55 | 0 | 100% | |
| venv/lib/python3.12/site-packages/skore/sklearn/_cross_validation | | | | |
| __init__.py | 5 | 0 | 100% | |
| metrics_accessor.py | 207 | 1 | 99% | 327 |
| report.py | 118 | 0 | 100% | |
| venv/lib/python3.12/site-packages/skore/sklearn/_estimator | | | | |
| __init__.py | 7 | 0 | 100% | |
| feature_importance_accessor.py | 143 | 2 | 98% | 216–217 |
| metrics_accessor.py | 371 | 9 | 97% | 158, 187, 189, 196, 287, 356, 360, 375, 410 |
| report.py | 155 | 0 | 100% | |
| venv/lib/python3.12/site-packages/skore/sklearn/_plot | | | | |
| __init__.py | 2 | 0 | 100% | |
| base.py | 5 | 0 | 100% | |
| style.py | 28 | 0 | 100% | |
| utils.py | 136 | 5 | 96% | 51, 75–77, 81 |
| venv/lib/python3.12/site-packages/skore/sklearn/_plot/metrics | | | | |
| __init__.py | 5 | 0 | 100% | |
| confusion_matrix.py | 69 | 4 | 94% | 90, 98, 120, 228 |
| precision_recall_curve.py | 230 | 1 | 99% | 716 |
| prediction_error.py | 160 | 0 | 100% | |
| roc_curve.py | 295 | 37 | 87% | 381, 431–433, 435–440, 442, 446, 451–452, 454, 460–462, 464–465, 470, 475, 477, 483–484, 490, 492–493, 495–496, 498, 607, 708, 874, 914, 1052, 1083 |
| venv/lib/python3.12/site-packages/skore/sklearn/train_test_split | | | | |
| __init__.py | 0 | 0 | 100% | |
| train_test_split.py | 49 | 0 | 100% | |
| venv/lib/python3.12/site-packages/skore/sklearn/train_test_split/warning | | | | |
| __init__.py | 8 | 0 | 100% | |
| high_class_imbalance_too_few_examples_warning.py | 17 | 1 | 94% | 80 |
| high_class_imbalance_warning.py | 18 | 0 | 100% | |
| random_state_unset_warning.py | 10 | 0 | 100% | |
| shuffle_true_warning.py | 10 | 1 | 90% | 46 |
| stratify_is_set_warning.py | 10 | 0 | 100% | |
| time_based_column_warning.py | 21 | 1 | 95% | 73 |
| train_test_split_warning.py | 4 | 0 | 100% | |
| venv/lib/python3.12/site-packages/skore/utils | | | | |
| __init__.py | 6 | 2 | 66% | 8, 13 |
| _accessor.py | 52 | 2 | 96% | 67, 108 |
| _environment.py | 27 | 0 | 100% | |
| _fixes.py | 8 | 0 | 100% | |
| _index.py | 5 | 0 | 100% | |
| _logger.py | 22 | 4 | 81% | 15–17, 19 |
| _measure_time.py | 10 | 0 | 100% | |
| _parallel.py | 38 | 3 | 92% | 23, 33, 124 |
| _patch.py | 13 | 5 | 61% | 21, 23–24, 35, 37 |
| _progress_bar.py | 45 | 0 | 100% | |
| _show_versions.py | 33 | 2 | 93% | 65–66 |
| _testing.py | 37 | 0 | 100% | |
| TOTAL | 3311 | 106 | 96% | |
| Tests | Skipped | Failures | Errors | Time |
|---|---|---|---|---|
| 816 | 5 💤 | 0 ❌ | 0 🔥 | 1m 1s ⏱️ |
Co-authored-by: Auguste Baum <[email protected]>
It looks like you're still working on this PR; if so, can you set it to draft?
Thanks for taking another look; I've just changed it to draft. I could do with advice on a couple of points, please:
    display = display_class._compute_data_for_display(
        y_true=y_true,
        y_pred=y_pred,
        average=average,
Mypy is still complaining about this line, even after casting:

    Unexpected keyword argument "average" for "_compute_data_for_display" of "PrecisionRecallCurveDisplay"

Without casting it shows the same error twice. Any suggestions on why?
Hm, the cast looks correct to me... maybe reveal_type could help?
Side-note: This is a sign that _get_display should stop existing, although that should be the subject of another PR.
> Hm, the cast looks correct to me... maybe reveal_type could help?
Thanks, though it looks how I would expect after casting:

    Revealed type is "type[skore.sklearn._plot.metrics.roc_curve.RocCurveDisplay]"
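For context on the cast discussion: `typing.cast` changes only the static type that mypy sees and is a no-op at runtime, so if mypy still reports the pre-cast signature, it is usually resolving a different expression than the one being called. A minimal sketch with hypothetical stand-in classes (not skore's actual display API):

```python
from typing import cast


class BaseDisplay:
    @classmethod
    def _compute_data_for_display(cls, **kwargs):
        # Simplified stand-in: a real display would compute curve data here.
        return cls.__name__, kwargs


class RocCurveDisplay(BaseDisplay):
    pass


# cast() narrows the type for static analysis only; at runtime the
# object is returned unchanged, so the call below still dispatches
# on BaseDisplay.
display_class = cast("type[RocCurveDisplay]", BaseDisplay)
name, passed = display_class._compute_data_for_display(average="macro")
```

Because the cast is purely static, `reveal_type(display_class)` under mypy reports `type[RocCurveDisplay]` even though `display_class is BaseDisplay` at runtime.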
    average_roc_curve = self.roc_curve.query(query)
    average_roc_auc = self.roc_auc.query(query)["roc_auc"].item()
    line_kwargs_validated = _validate_style_kwargs({}, {})
I'm not sure how best to take style kwargs for the average line. Other lines slice using split_idx, but this is None for the average line.
Yeah, it might be time to make roc_curve_kwargs more specific in this case. Something like:

    from typing import Any, Optional, TypedDict, TypeVar, Union

    T = TypeVar("T")
    Kwargs = dict[str, Any]
    OneOrMore = Union[T, list[T]]

    class RocCurveKwargs(TypedDict):
        splits: OneOrMore[Kwargs]
        average: Optional[Kwargs]
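For illustration, a plot method consuming such a structure would need to broadcast a single kwargs dict to every split or validate a per-split list. A self-contained sketch under that assumption (`resolve_split_kwargs` is hypothetical, not part of skore):

```python
from typing import Any, Union

Kwargs = dict[str, Any]


def resolve_split_kwargs(splits: Union[Kwargs, list[Kwargs]], n_splits: int) -> list[Kwargs]:
    """Broadcast a single style-kwargs dict to every split, or validate
    that a per-split list has one entry per split."""
    if isinstance(splits, dict):
        # One dict given: copy it once per split so callers can mutate safely.
        return [dict(splits) for _ in range(n_splits)]
    if len(splits) != n_splits:
        raise ValueError(f"expected {n_splits} kwargs dicts, got {len(splits)}")
    return list(splits)
```

The `average` entry would then bypass this broadcasting entirely, sidestepping the split_idx-is-None problem.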
Implements the threshold averaging method for ROC curve averaging (see #1702).
Includes a simple test for the averaging logic.
Still to do:
- `_MetricsAccessor` uses the same interface to call both PR and ROC displays.
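For reference, the threshold-averaging idea can be sketched as follows: pick a common grid of thresholds, look up each split's (FPR, TPR) point at every threshold, and average the points, so both coordinates vary (unlike vertical averaging, which fixes FPR). This is a minimal illustration, not the PR's implementation; the inputs are assumed to be per-split outputs of `sklearn.metrics.roc_curve`:

```python
import numpy as np


def threshold_average_roc(per_split_thresholds, per_split_fpr, per_split_tpr, n_points=101):
    """Average ROC curves at common decision thresholds."""
    all_t = np.concatenate([np.asarray(t) for t in per_split_thresholds])
    grid = np.linspace(all_t.min(), all_t.max(), n_points)
    avg_fpr, avg_tpr = [], []
    for thr, fpr, tpr in zip(per_split_thresholds, per_split_fpr, per_split_tpr):
        # roc_curve returns thresholds in decreasing order; sort so
        # np.interp sees monotonically increasing x values.
        order = np.argsort(thr)
        thr = np.asarray(thr)[order]
        avg_fpr.append(np.interp(grid, thr, np.asarray(fpr)[order]))
        avg_tpr.append(np.interp(grid, thr, np.asarray(tpr)[order]))
    # Average each threshold's (FPR, TPR) point across splits.
    return np.mean(avg_fpr, axis=0), np.mean(avg_tpr, axis=0)
```

Unlike vertical averaging, the resulting mean FPR is generally not uniformly spaced, which is why the average line needs its own styling path rather than reusing the per-split one.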