Skip to content

Commit ff2aab0

Browse files
committed
feature importance
1 parent f6f6e87 commit ff2aab0

File tree

9 files changed

+815
-482
lines changed

9 files changed

+815
-482
lines changed

docs/pre_executed/demo.ipynb

Lines changed: 734 additions & 252 deletions
Large diffs are not rendered by default.

src/uncle_val/datasets/dp1.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
import pandas as pd
88
import pyarrow as pa
9+
from nested_pandas import NestedFrame
910
from upath import UPath
1011

1112
from uncle_val.variability_detectors import get_combined_variability_detector
@@ -83,10 +84,10 @@ def _split_light_curves_by_band(
8384
single_band["band"] = band
8485

8586
single_band["object_mag"] = single_band[f"{band}_psfMag"]
86-
single_band = single_band.drop(columns=[f"{band}_psfMag" for band in LSDB_BANDS])
87+
single_band = single_band.drop(columns=[f"{b}_psfMag" for b in bands])
8788

8889
single_band["extendedness"] = single_band[f"{band}_extendedness"]
89-
single_band = single_band.drop(columns=[f"{band}_extendedness" for band in LSDB_BANDS])
90+
single_band = single_band.drop(columns=[f"{b}_extendedness" for b in bands])
9091

9192
single_band_dfs.append(single_band)
9293

@@ -357,6 +358,7 @@ def dp1_catalog_multi_band(
357358
phot: Literal["PSF"],
358359
mode: Literal["forced"],
359360
variability_detectors: Sequence[Callable] | Literal["all"] = "all",
361+
pre_filter_partition: Callable[[NestedFrame], NestedFrame] | None = None,
360362
):
361363
"""Rubin DP1 LSDB catalog, bands are one-hot encoded.
362364
@@ -397,6 +399,10 @@ def dp1_catalog_multi_band(
397399
Which variability detectors to pass to
398400
`get_combined_variability_detector()`, default passing `None` which means
399401
using all of them.
402+
pre_filter_partition : callable or None
403+
Optional function applied to each catalog partition before any other
404+
processing. Receives a ``NestedFrame`` and returns a filtered
405+
``NestedFrame``.
400406
401407
Returns
402408
-------
@@ -421,6 +427,9 @@ def dp1_catalog_multi_band(
421427
read_visit_cols=True,
422428
)
423429

430+
if pre_filter_partition is not None:
431+
catalog = catalog.map_partitions(pre_filter_partition)
432+
424433
if variability_detectors == "all":
425434
variability_detectors = None
426435
var_detector = get_combined_variability_detector(variability_detectors)

src/uncle_val/learning/lsdb_dataset.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,23 @@ def _reduce_all_columns_wrapper(*args, columns=None, udf, **kwargs):
2222

2323

2424
def _process_lc(
25-
row: dict[str, object], *, n_src: int, lc_col: str, length_col: str, rng: np.random.Generator
25+
row: dict[str, object],
26+
*,
27+
n_src: int,
28+
subsample_src: bool,
29+
lc_col: str,
30+
length_col: str,
31+
rng: np.random.Generator,
2632
) -> dict[str, np.ndarray]:
2733
lc_length = row.pop(length_col)
28-
idx = rng.choice(lc_length, size=n_src, replace=False)
34+
idx = rng.choice(lc_length, size=n_src, replace=False) if subsample_src else np.arange(lc_length)
2935

3036
result: dict[str, np.ndarray] = {}
3137
for col, value in row.items():
3238
if col.startswith(f"{lc_col}."):
3339
result[col] = value[idx]
3440
else:
35-
result[f"{lc_col}.{col}"] = np.full(n_src, value)
41+
result[f"{lc_col}.{col}"] = np.full(len(idx), value)
3642
return result
3743

3844

@@ -41,6 +47,7 @@ def _process_partition(
4147
pixel: HealpixPixel,
4248
*,
4349
n_src: int,
50+
subsample_src: bool,
4451
lc_col: str,
4552
id_col: str,
4653
hash_range: tuple[int, int] | None,
@@ -77,6 +84,7 @@ def _process_partition(
7784
columns=columns,
7885
udf=_process_lc,
7986
n_src=n_src,
87+
subsample_src=subsample_src,
8088
lc_col=lc_col,
8189
length_col=length_col,
8290
rng=rng,
@@ -92,6 +100,7 @@ def lsdb_nested_series_data_generator(
92100
id_col: str = "id",
93101
client: dask.distributed.Client | None,
94102
n_src: int,
103+
subsample_src: bool = True,
95104
partitions_per_chunk: int | None,
96105
hash_range: tuple[int, int] | None = None,
97106
loop: bool = False,
@@ -101,8 +110,10 @@ def lsdb_nested_series_data_generator(
101110
102111
The data is pre-fetched in the background, 'n_workers' number
103112
of partitions at a time (derived from `client` object).
104-
It filters out light curves with less than `n_src` observations,
105-
and selects `n_src` random observations per light curve.
113+
Filters out light curves with fewer than `n_src` observations.
114+
If `subsample_src` is ``True``, selects exactly `n_src` random observations
115+
per light curve. If ``False``, all observations from qualifying light curves
116+
are included.
106117
107118
Parameters
108119
----------
@@ -118,7 +129,12 @@ def lsdb_nested_series_data_generator(
118129
value. If Dask client is given, the data would be fetched on the
119130
background.
120131
n_src : int
121-
Number of random observations per light curve.
132+
Minimum number of observations required per light curve. Also the
133+
subsample target when `subsample_src` is ``True``.
134+
subsample_src : bool, optional
135+
If ``True`` (default), randomly subsample exactly `n_src` observations
136+
per light curve. If ``False``, include all observations from qualifying
137+
light curves.
122138
partitions_per_chunk : int
123139
Number of `catalog` partitions to load in memory simultaneously.
124140
This changes the randomness.
@@ -151,6 +167,7 @@ def lsdb_nested_series_data_generator(
151167
_process_partition,
152168
include_pixel=True,
153169
n_src=n_src,
170+
subsample_src=subsample_src,
154171
lc_col=lc_col,
155172
id_col=id_col,
156173
hash_range=hash_range,
@@ -205,7 +222,8 @@ class LSDBIterableDataset(IterableDataset):
205222
Number of batches to yield. If `splits` is used, it will be the size
206223
of the first subset.
207224
n_src : int
208-
Number of random observations per light curve.
225+
Number of random observations per light curve. Light curves with fewer
226+
than `n_src` observations are filtered out.
209227
partitions_per_chunk : int or None
210228
Number of `catalog` partitions per time, if None it is derived
211229
from the number of dask workers associated with `Client` (one if
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from .dp1_constant_magerr import run_dp1_constant_magerr
2+
from .dp1_feature_importance import run_dp1_feature_importance
23
from .dp1_linear_flux_err import run_dp1_linear_flux_err
34
from .dp1_mlp import run_dp1_mlp
4-
from .plotting import make_plots, plot_shap_summary
5-
from .validation_set_utils import compute_shap_values
5+
from .plotting import make_plots
66

77
__all__ = (
8-
"compute_shap_values",
98
"make_plots",
10-
"plot_shap_summary",
119
"run_dp1_constant_magerr",
10+
"run_dp1_feature_importance",
1211
"run_dp1_linear_flux_err",
1312
"run_dp1_mlp",
1413
)

src/uncle_val/pipelines/dp1_mlp.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from collections.abc import Callable, Sequence
12
from pathlib import Path
23

34
import torch
5+
from nested_pandas import NestedFrame
46

57
from uncle_val.datasets.dp1 import dp1_catalog_multi_band
68
from uncle_val.learning.losses import UncleLoss
@@ -26,6 +28,8 @@ def run_dp1_mlp(
2628
log_activations: bool = False,
2729
snapshot_every: int = 128,
2830
device: torch.device | str = "cpu",
31+
bands: Sequence[str] = "ugrizy",
32+
pre_filter_partition: Callable[[NestedFrame], NestedFrame] | None = None,
2933
) -> tuple[Path, list[str]]:
3034
"""Run the training for DP1 with the linear model on fluxes and errors
3135
@@ -68,6 +72,12 @@ def run_dp1_mlp(
6872
Whether to log validation activations with TensorBoard session.
6973
device : torch.device | str
7074
Torch device to use for training.
75+
bands : sequence of str
76+
Bands to include, subset of ``ugrizy``. Defaults to all six bands.
77+
pre_filter_partition : callable or None
78+
Optional function applied to each catalog partition before any other
79+
processing. Receives a ``NestedFrame`` and returns a filtered
80+
``NestedFrame``.
7181
7282
Returns
7383
-------
@@ -76,15 +86,14 @@ def run_dp1_mlp(
7686
list[str]
7787
List of columns to use as model inputs.
7888
"""
79-
bands = "ugrizy"
80-
8189
catalog = dp1_catalog_multi_band(
8290
root=dp1_root,
8391
bands=bands,
8492
obj="science",
8593
img="cal",
8694
phot="PSF",
8795
mode="forced",
96+
pre_filter_partition=pre_filter_partition,
8897
).map_partitions(lambda df: df.drop(columns=["band", "object_mag", "coord_ra", "coord_dec"]))
8998

9099
columns = [

src/uncle_val/pipelines/plotting.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -302,56 +302,6 @@ def _plot_magn_vs_uu(
302302
ax.legend()
303303

304304

305-
def plot_shap_summary(
306-
shap_values: np.ndarray,
307-
feature_data: np.ndarray,
308-
input_names: list[str],
309-
*,
310-
output_path: str | Path | None = None,
311-
title: str = "SHAP Feature Importance",
312-
) -> plt.Figure:
313-
"""Plot a SHAP beeswarm summary for the predicted uncertainty factor ``u``.
314-
315-
Each dot represents one observation, positioned by its SHAP value (impact
316-
on ``u``) and coloured by the raw feature value (red = high, blue = low).
317-
318-
Parameters
319-
----------
320-
shap_values : np.ndarray, shape ``(n_samples, n_features)``
321-
SHAP values as returned by
322-
:func:`~uncle_val.pipelines.validation_set_utils.compute_shap_values`.
323-
feature_data : np.ndarray, shape ``(n_samples, n_features)``
324-
Raw feature values corresponding to *shap_values*.
325-
input_names : list of str
326-
Feature names in the order of the last dimension.
327-
output_path : str, Path, or None
328-
If given, save the figure to this path.
329-
title : str
330-
Figure title.
331-
332-
Returns
333-
-------
334-
matplotlib.figure.Figure
335-
The created figure.
336-
"""
337-
import shap
338-
339-
plt.close("all")
340-
explanation = shap.Explanation(
341-
values=shap_values,
342-
data=feature_data,
343-
feature_names=input_names,
344-
)
345-
shap.plots.beeswarm(explanation, show=False, max_display=len(input_names))
346-
fig = plt.gcf()
347-
fig.suptitle(title, y=1.01)
348-
349-
if output_path is not None:
350-
fig.savefig(output_path, bbox_inches="tight")
351-
352-
return fig
353-
354-
355305
def make_plots(
356306
dp1_root: str | Path,
357307
*,

src/uncle_val/pipelines/splits.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
TRAIN_SPLIT = 0.0, 0.75
2-
VALIDATION_SPLIT = 0.75, 0.85
1+
TRAIN_SPLIT = 0.0, 0.6
2+
VALIDATION_SPLIT = 0.6, 0.85
33
TEST_SPLIT = 0.85, 1.0

src/uncle_val/pipelines/training_loop.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,15 @@
1212
from torch.utils.tensorboard import SummaryWriter
1313
from tqdm.auto import tqdm
1414

15+
from uncle_val.datasets.materialized import MaterializedDataLoaderContext
16+
from uncle_val.learning.feature_importance import compute_shap_values, plot_shap_summary
1517
from uncle_val.learning.losses import UncleLoss
1618
from uncle_val.learning.lsdb_dataset import LSDBIterableDataset
1719
from uncle_val.learning.models import BaseUncleModel
1820
from uncle_val.learning.training import train_step
19-
from uncle_val.pipelines.plotting import plot_shap_summary
20-
from uncle_val.pipelines.splits import TRAIN_SPLIT, VALIDATION_SPLIT
21+
from uncle_val.pipelines.splits import TEST_SPLIT, TRAIN_SPLIT, VALIDATION_SPLIT
2122
from uncle_val.pipelines.utils import _launch_tfboard
22-
from uncle_val.pipelines.validation_set_utils import (
23-
ValidationDataLoaderContext,
24-
compute_shap_values,
25-
get_val_stats,
26-
)
23+
from uncle_val.pipelines.validation_set_utils import get_val_stats
2724

2825

2926
def get_val_workers(client: Client, device: torch.device) -> list[object] | None:
@@ -147,7 +144,7 @@ def training_loop(
147144
device=device,
148145
)
149146

150-
with ValidationDataLoaderContext(validation_dataset_lsdb, tmp_validation_dir) as val_dataloader:
147+
with MaterializedDataLoaderContext(validation_dataset_lsdb, tmp_validation_dir) as val_dataloader:
151148
val_stats_future: Future | None = None
152149
mean_val_loss_i = 0
153150

@@ -268,19 +265,33 @@ def snapshot(i):
268265
snapshot(i_train_batch)
269266
snapshot(i_train_batch)
270267

271-
if best_model_path is not None and model.input_names:
268+
if best_model_path is not None and model.input_names:
269+
test_dataset_lsdb = LSDBIterableDataset(
270+
catalog=catalog,
271+
columns=columns,
272+
client=client,
273+
batch_lc=val_batch_size,
274+
n_src=n_src,
275+
partitions_per_chunk=n_workers * 8,
276+
loop=False,
277+
hash_range=TEST_SPLIT,
278+
seed=2,
279+
device=device,
280+
)
281+
tmp_test_dir = output_dir / "test_shap"
282+
with MaterializedDataLoaderContext(test_dataset_lsdb, tmp_test_dir) as test_dataloader:
272283
shap_values, feature_data = compute_shap_values(
273284
model_path=best_model_path,
274-
data_loader=val_dataloader,
285+
data_loader=test_dataloader,
275286
device=device,
276287
)
277-
fig = plot_shap_summary(
278-
shap_values,
279-
feature_data,
280-
input_names=model.input_names,
281-
output_path=output_dir / "feature_importance.png",
282-
)
283-
summary_writer.add_figure("Feature importance", fig)
288+
fig = plot_shap_summary(
289+
shap_values,
290+
feature_data,
291+
input_names=model.input_names,
292+
output_path=output_dir / "feature_importance.png",
293+
)
294+
summary_writer.add_figure("Feature importance", fig)
284295

285296
model.eval()
286297
summary_writer.add_graph(model, train_batch[0])

0 commit comments

Comments
 (0)