
Commit 4391e91

Implement loss decomposition metrics and wrappers for early stopping (#13)
* Enhance .gitignore to exclude Cursor IDE files and add a new example script demonstrating TS-refinement early stopping for XGBoost. Update the splinator module with new metrics and temperature scaling utilities for improved calibration and loss decomposition. Introduce metric wrappers for compatibility with various ML frameworks, and add comprehensive tests for the new functionality.
* Update the README and examples to include TS-refinement metrics and early-stopping guidance. Enhance the documentation with new references and examples that clarify the calibration techniques, and update the references and descriptions in the metrics module.
* Revise the README for clarity and update the examples to reflect the new metrics and calibration techniques. Introduce a `spline_refinement_loss` function for more flexible calibration, and align the tests with the new `logloss_decomposition` naming convention.
1 parent f0c85a0 commit 4391e91
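The wrapper mechanism this commit describes, adapting a plain `metric(y_true, y_pred)` function to a boosting framework's custom-metric callback, can be sketched as below. This is an illustrative sketch only: `make_metric_wrapper_sketch`, the stand-in `log_loss`, and `DMatrixStub` are hypothetical names, not splinator's actual API; the adapted signature mimics XGBoost's `custom_metric` convention of `f(preds, dtrain) -> (name, value)`.

```python
import numpy as np

def make_metric_wrapper_sketch(metric_fn, name="ts_refinement"):
    """Adapt a metric f(y_true, y_pred) -> float to XGBoost's
    custom_metric signature f(preds, dtrain) -> (name, value)."""
    def wrapped(preds, dtrain):
        y_true = dtrain.get_label()
        return name, float(metric_fn(y_true, preds))
    return wrapped

def log_loss(y_true, probs, eps=1e-12):
    # Plain log loss as a stand-in for a refinement metric.
    p = np.clip(probs, eps, 1 - eps)
    return -np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))

class DMatrixStub:
    # Minimal stand-in exposing the get_label() accessor the wrapper uses.
    def __init__(self, labels):
        self._labels = np.asarray(labels, dtype=float)

    def get_label(self):
        return self._labels

metric = make_metric_wrapper_sketch(log_loss)
name, value = metric(np.array([0.9, 0.1, 0.8]), DMatrixStub([1, 0, 1]))
```

A wrapper like this lets an early-stopping loop monitor a calibration-aware loss instead of the default eval metric, without the metric function knowing anything about the framework.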

File tree

9 files changed: +2456 −48 lines

.gitignore

Lines changed: 3 additions & 0 deletions
```diff
@@ -164,5 +164,8 @@ cython_debug/
 .DS_Store
 .idea/
 
+# Cursor IDE
+.cursor/
+
 # Local issue drafts
 .github/ISSUES/
```

README.md

Lines changed: 70 additions & 47 deletions
````diff
@@ -1,78 +1,101 @@
 # Splinator 📈
 
-**Probablistic Calibration with Regression Splines**
+**Probability Calibration for Python**
 
-[scikit-learn](https://scikit-learn.org) compatible
+A scikit-learn compatible toolkit for measuring and improving probability calibration.
 
-[![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv)
+[![PyPI version](https://img.shields.io/pypi/v/splinator)](https://pypi.org/project/splinator/)
+[![Downloads](https://static.pepy.tech/badge/splinator)](https://pepy.tech/project/splinator)
+[![Downloads/Month](https://static.pepy.tech/badge/splinator/month)](https://pepy.tech/project/splinator)
 [![Documentation Status](https://readthedocs.org/projects/splinator/badge/?version=latest)](https://splinator.readthedocs.io/en/latest/)
 [![Build](https://img.shields.io/github/actions/workflow/status/affirm/splinator/.github/workflows/python-package.yml)](https://github.com/affirm/splinator/actions)
 
 ## Installation
 
-`pip install splinator`
+```bash
+pip install splinator
+```
 
-## Algorithm
+## What's Inside
 
-Supported models:
+| Category | Components |
+|----------|------------|
+| **Calibrators** | `LinearSplineLogisticRegression` (piecewise), `TemperatureScaling` (single param) |
+| **Refinement Metrics** | `spline_refinement_loss`, `ts_refinement_loss` |
+| **Decomposition** | `logloss_decomposition`, `brier_decomposition` |
+| **Calibration Metrics** | ECE, Spiegelhalter's z |
 
-- Linear Spline Logistic Regression
+## Quick Start
 
-Supported metrics:
+```python
+from splinator import LinearSplineLogisticRegression, TemperatureScaling
 
-- Spiegelhalter’s z statistic
-- Expected Calibration Error (ECE)
+# Piecewise linear calibration (flexible, monotonic)
+spline = LinearSplineLogisticRegression(n_knots=10, monotonicity='increasing')
+spline.fit(scores.reshape(-1, 1), y_true)
+calibrated = spline.predict_proba(scores.reshape(-1, 1))[:, 1]
 
-\[1\] You can find more information in the [Linear Spline Logistic
-Regression](https://github.com/Affirm/splinator/wiki/Linear-Spline-Logistic-Regression).
+# Temperature scaling (simple, single parameter)
+ts = TemperatureScaling()
+ts.fit(probs.reshape(-1, 1), y_true)
+calibrated = ts.predict(probs.reshape(-1, 1))
+```
 
-\[2\] Additional readings
+## Calibration Metrics
 
-- Zhang, Jian, and Yiming Yang. [Probabilistic score estimation with
-  piecewise logistic
-  regression](https://pal.sri.com/wp-content/uploads/publications/radar/2004/icml04zhang.pdf).
-  Proceedings of the twenty-first international conference on Machine
-  learning. 2004.
-- Guo, Chuan, et al. "On calibration of modern neural networks." International conference on machine learning. PMLR, 2017.
+```python
+from splinator import (
+    expected_calibration_error,
+    spiegelhalters_z_statistic,
+    logloss_decomposition,   # Log loss → refinement + calibration
+    brier_decomposition,     # Brier score → refinement + calibration
+    spline_refinement_loss,  # Log loss after piecewise spline
+)
 
+# Assess calibration quality
+ece = expected_calibration_error(y_true, probs)
+z_stat = spiegelhalters_z_statistic(y_true, probs)
 
-## Examples
+# Decompose log loss into fixable vs irreducible parts
+decomp = logloss_decomposition(y_true, probs)
+print(f"Refinement (irreducible): {decomp['refinement_loss']:.4f}")
+print(f"Calibration (fixable): {decomp['calibration_loss']:.4f}")
 
-| comparison | notebook |
-|------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| scikit-learn's sigmoid and isotonic regression | [![colab1](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/Affirm/splinator/blob/main/examples/calibrator_model_comparison.ipynb) |
-| pyGAM’s spline model | [![colab2](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/Affirm/splinator/blob/main/examples/spline_model_comparison.ipynb) |
+# Refinement using splinator's piecewise calibrator
+spline_ref = spline_refinement_loss(y_val, probs, n_knots=5)
+```
 
-## Development
+## XGBoost / LightGBM Integration
 
-The dependencies are managed by [uv](https://github.com/astral-sh/uv).
+Use calibration-aware metrics for early stopping:
 
-```bash
-# Install uv (if not already installed)
-curl -LsSf https://astral.sh/uv/install.sh | sh
+```python
+from splinator import ts_refinement_loss
+from splinator.metric_wrappers import make_metric_wrapper
 
-# Create virtual environment and install dependencies
-uv sync --dev
+metric = make_metric_wrapper(ts_refinement_loss, framework='xgboost')
+model = xgb.train(params, dtrain, custom_metric=metric, early_stopping_rounds=10, ...)
+```
 
-# Run tests
-uv run pytest tests -v
+## Examples
 
-# Run type checking
-uv run mypy src/splinator
-```
+| Notebook | Description |
+|----------|-------------|
+| [calibrator_model_comparison](examples/calibrator_model_comparison.ipynb) | Compare with sklearn calibrators |
+| [spline_model_comparison](examples/spline_model_comparison.ipynb) | Compare with pyGAM |
+| [ts_refinement_xgboost](examples/ts_refinement_xgboost.py) | Early stopping with refinement loss |
 
-## Example Usage
+## References
 
-``` python
-from splinator.estimators import LinearSplineLogisticRegression
-import numpy as np
+- Zhang, J. & Yang, Y. (2004). [Probabilistic score estimation with piecewise logistic regression](https://pal.sri.com/wp-content/uploads/publications/radar/2004/icml04zhang.pdf). ICML.
+- Guo, C., Pleiss, G., Sun, Y. & Weinberger, K. Q. (2017). [On calibration of modern neural networks](https://arxiv.org/abs/1706.04599). ICML.
+- Berta, E., Holzmüller, D., Jordan, M. I. & Bach, F. (2025). [Rethinking Early Stopping: Refine, Then Calibrate](https://arxiv.org/abs/2501.19195). arXiv:2501.19195.
 
-# random synthetic dataset
-n_samples = 100
-rng = np.random.RandomState(0)
-X = rng.normal(loc=100, size=(n_samples, 2))
-y = np.random.randint(2, size=n_samples)
+See also: [probmetrics](https://github.com/dholzmueller/probmetrics) (PyTorch calibration by the refinement paper authors)
 
-lslr = LinearSplineLogisticRegression(n_knots=10)
-lslr.fit(X, y)
+## Development
+
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh  # Install uv
+uv sync --dev && uv run pytest tests -v          # Setup and test
 ```
````
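The refinement/calibration split referenced in the updated README follows the "refine, then calibrate" idea: the refinement term is the loss that remains after an ideal recalibration of the predictions, and the calibration term is the remainder, the part a post-hoc calibrator could fix. A minimal NumPy sketch of that decomposition, using a grid-searched temperature rescaling as the recalibrator; the function name is hypothetical and does not reproduce splinator's implementation:

```python
import numpy as np

def _log_loss(y, p, eps=1e-12):
    # Mean binary cross-entropy with clipping for numerical safety.
    p = np.clip(p, eps, 1 - eps)
    return -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))

def logloss_decomposition_sketch(y_true, probs):
    """Split log loss into a refinement part (loss remaining after the
    best temperature rescaling) and a calibration part (the rest)."""
    y = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(probs, dtype=float), 1e-12, 1 - 1e-12)
    logits = np.log(p / (1 - p))
    # Grid-search the temperature; T = 1 is included explicitly so the
    # calibration term can never be negative.
    temps = np.concatenate([[1.0], np.linspace(0.05, 10.0, 400)])
    losses = [_log_loss(y, 1.0 / (1.0 + np.exp(-logits / t))) for t in temps]
    total = _log_loss(y, p)
    refinement = min(losses)
    return {"total_loss": total,
            "refinement_loss": refinement,
            "calibration_loss": total - refinement}
```

By construction the two parts sum exactly to the total log loss, and because T = 1 is in the grid the calibration term is nonnegative. Splinator's `logloss_decomposition` presumably uses a more principled recalibrator than this coarse grid, but the accounting identity is the same.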
