Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,239 changes: 734 additions & 505 deletions docs/pre_executed/demo.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ dependencies = [
"numpy>=2",
"onnxscript>=0.5.3",
"scipy>=1",
"setuptools<70", # provides pkg_resources, required by tensorboard but missing from its dependencies
"tensorboard>=2",
"torch>=2.8.0,<3",
"shap",
"tqdm>=4",
]

Expand Down
13 changes: 11 additions & 2 deletions src/uncle_val/datasets/dp1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
import pyarrow as pa
from nested_pandas import NestedFrame
from upath import UPath

from uncle_val.variability_detectors import get_combined_variability_detector
Expand Down Expand Up @@ -83,10 +84,10 @@ def _split_light_curves_by_band(
single_band["band"] = band

single_band["object_mag"] = single_band[f"{band}_psfMag"]
single_band = single_band.drop(columns=[f"{band}_psfMag" for band in LSDB_BANDS])
single_band = single_band.drop(columns=[f"{b}_psfMag" for b in bands])

single_band["extendedness"] = single_band[f"{band}_extendedness"]
single_band = single_band.drop(columns=[f"{band}_extendedness" for band in LSDB_BANDS])
single_band = single_band.drop(columns=[f"{b}_extendedness" for b in bands])

single_band_dfs.append(single_band)

Expand Down Expand Up @@ -357,6 +358,7 @@ def dp1_catalog_multi_band(
phot: Literal["PSF"],
mode: Literal["forced"],
variability_detectors: Sequence[Callable] | Literal["all"] = "all",
pre_filter_partition: Callable[[NestedFrame], NestedFrame] | None = None,
):
"""Rubin DP1 LSDB catalog, bands are one-hot encoded.

Expand Down Expand Up @@ -397,6 +399,10 @@ def dp1_catalog_multi_band(
Which variability detectors are to pass to
`get_combined_variability_detector()`, default passing `None` which means
using all of them.
pre_filter_partition : callable or None
Optional function applied to each catalog partition before any other
processing. Receives a ``NestedFrame`` and returns a filtered
``NestedFrame``.

Returns
-------
Expand All @@ -421,6 +427,9 @@ def dp1_catalog_multi_band(
read_visit_cols=True,
)

if pre_filter_partition is not None:
catalog = catalog.map_partitions(pre_filter_partition)

if variability_detectors == "all":
variability_detectors = None
var_detector = get_combined_variability_detector(variability_detectors)
Expand Down
77 changes: 77 additions & 0 deletions src/uncle_val/datasets/materialized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import shutil
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset

from uncle_val.learning.lsdb_dataset import LSDBIterableDataset


class MaterializedDataset(Dataset):
    """Map-style dataset backed by serialized ``*.pt`` tensor files on disk.

    Each file in *data_dir* holds one item; items are indexed in sorted
    (lexicographic) filename order.

    Parameters
    ----------
    data_dir : Path
        Directory containing the ``*.pt`` files.
    device : torch.device
        Device onto which loaded tensors are mapped.
    """

    def __init__(self, *, data_dir: Path, device: torch.device):
        # Sorted so that item order is deterministic across runs.
        self.paths = sorted(data_dir.glob("*.pt"))
        self.device = device

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        return torch.load(path, map_location=self.device)


class MaterializedDataLoaderContext:
"""Context manager that materializes an LSDB dataset to disk and serves it as a DataLoader.

Tensors are saved to a temporary directory on context entry and
deleted on context exit.

Parameters
----------
input_dataset : LSDBIterableDataset
Dataset which yields data to materialize.
tmp_dir : Path or str
Temporary directory to save data into. It will be created on
the context entrance, and filled with the data.
It will be deleted on the context exit.
"""

def __init__(self, input_dataset: LSDBIterableDataset, tmp_dir: Path | str):
self.input_dataset = input_dataset
self.tmp_dir = Path(tmp_dir)

def _serialize_data(self):
n_chunks = 0
for chunk in self.input_dataset:
if n_chunks >= 1e6:
raise RuntimeError("Number of chunks is more than a million!!!")
torch.save(chunk, self.tmp_dir / f"chunk_{n_chunks:05d}.pt")
n_chunks += 1
if n_chunks == 0:
raise RuntimeError(
"Dataset yielded no batches. "
"Check that your catalog has enough objects in the requested hash range "
"after all filters, and that batch_size is not larger than the total "
"number of qualifying light curves."
)

def __enter__(self) -> DataLoader:
self.tmp_dir.mkdir(parents=True, exist_ok=True)
self._serialize_data()
dataset = MaterializedDataset(data_dir=self.tmp_dir, device=self.input_dataset.device)
return DataLoader(
dataset,
shuffle=False,
)

def __exit__(self, exc_type, exc_val, exc_tb):
shutil.rmtree(self.tmp_dir)
135 changes: 135 additions & 0 deletions src/uncle_val/learning/feature_importance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from collections.abc import Iterable
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import shap
import torch


class _FlatWrapper(torch.nn.Module):
"""Wraps an UncleModel to accept flat ``(N, n_features)`` input.

Treats N observations as a single light curve ``(1, N, n_features)``
so that each observation is processed independently, and returns the
per-observation uncertainty factor ``u`` as a 2-D tensor ``(N, 1)``.
"""

def __init__(self, model: torch.nn.Module) -> None:
super().__init__()
self.model = model

def forward(self, flat_inputs: torch.Tensor) -> torch.Tensor:
"""Forward pass returning u per observation."""
output = self.model(flat_inputs.unsqueeze(0)) # (1, N, d_output)
return output[0, :, :1] # (N, 1) — shap requires 2D output


def compute_shap_values(
    *,
    model_path: str | Path,
    data_loader: Iterable,
    device: torch.device,
    n_background: int = 500,
    n_test: int = 2000,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute SHAP values for the predicted uncertainty factor ``u``.

    The saved model is wrapped with a flat ``(N, n_features)`` adapter and
    explained with :class:`shap.GradientExplainer`. Background and test
    observations are taken from the first batches of *data_loader*.

    Parameters
    ----------
    model_path : str or Path
        Path to the saved model checkpoint.
    data_loader : iterable
        Yields batches of shape ``(batch_lc, n_src, n_features)`` or
        ``(1, batch_lc, n_src, n_features)`` (e.g. a raw
        ``LSDBIterableDataset`` or a ``DataLoader`` wrapping one).
    device : torch.device
        Device for model and data.
    n_background : int
        Number of background observations for the SHAP baseline.
    n_test : int
        Number of observations to explain.

    Returns
    -------
    shap_values : np.ndarray, shape ``(n_test, n_features)``
        SHAP values for each observation and feature.
    feature_data : np.ndarray, shape ``(n_test, n_features)``
        Raw feature values corresponding to *shap_values*, used for
        colouring the beeswarm plot.
    """
    model = torch.load(model_path, weights_only=False, map_location=device)
    model.eval()

    # Flatten incoming batches into individual observations until we have
    # enough rows for both the background and the test samples.
    n_needed = n_background + n_test
    collected: list[torch.Tensor] = []
    n_collected = 0
    for batch in data_loader:
        flat_batch = batch.squeeze(0).reshape(-1, batch.shape[-1])  # (batch_lc*n_src, n_features)
        collected.append(flat_batch)
        n_collected += flat_batch.shape[0]
        if n_collected >= n_needed:
            break
    observations = torch.cat(collected, dim=0)[:n_needed]

    background = observations[:n_background]
    test_data = observations[n_background:n_needed]

    explainer = shap.GradientExplainer(_FlatWrapper(model), background)
    shap_values = np.array(explainer.shap_values(test_data))
    # GradientExplainer returns (n_test, n_features, n_outputs); drop the output dim
    if shap_values.ndim == 3:
        shap_values = shap_values[..., 0]

    return shap_values, test_data.cpu().detach().numpy()


def plot_shap_summary(
    shap_values: np.ndarray,
    feature_data: np.ndarray,
    input_names: list[str],
    *,
    output_path: str | Path | None = None,
    title: str = "SHAP Feature Importance",
) -> plt.Figure:
    """Plot a SHAP beeswarm summary for the predicted uncertainty factor ``u``.

    Each dot is one observation, placed by its SHAP value (impact on ``u``)
    and coloured by the raw feature value (red = high, blue = low).

    Parameters
    ----------
    shap_values : np.ndarray, shape ``(n_samples, n_features)``
        SHAP values as returned by :func:`compute_shap_values`.
    feature_data : np.ndarray, shape ``(n_samples, n_features)``
        Raw feature values corresponding to *shap_values*.
    input_names : list of str
        Feature names in the order of the last dimension.
    output_path : str, Path, or None
        If given, save the figure to this path.
    title : str
        Figure title.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    # Start from a clean slate so the beeswarm draws on a fresh figure.
    plt.close("all")

    shap.plots.beeswarm(
        shap.Explanation(
            values=shap_values,
            data=feature_data,
            feature_names=input_names,
        ),
        show=False,
        max_display=len(input_names),
    )

    figure = plt.gcf()
    figure.suptitle(title, y=1.01)

    if output_path is not None:
        figure.savefig(output_path, bbox_inches="tight")

    return figure
32 changes: 25 additions & 7 deletions src/uncle_val/learning/lsdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,23 @@ def _reduce_all_columns_wrapper(*args, columns=None, udf, **kwargs):


def _process_lc(
row: dict[str, object], *, n_src: int, lc_col: str, length_col: str, rng: np.random.Generator
row: dict[str, object],
*,
n_src: int,
subsample_src: bool,
lc_col: str,
length_col: str,
rng: np.random.Generator,
) -> dict[str, np.ndarray]:
lc_length = row.pop(length_col)
idx = rng.choice(lc_length, size=n_src, replace=False)
idx = rng.choice(lc_length, size=n_src, replace=False) if subsample_src else np.arange(lc_length)

result: dict[str, np.ndarray] = {}
for col, value in row.items():
if col.startswith(f"{lc_col}."):
result[col] = value[idx]
else:
result[f"{lc_col}.{col}"] = np.full(n_src, value)
result[f"{lc_col}.{col}"] = np.full(len(idx), value)
return result


Expand All @@ -41,6 +47,7 @@ def _process_partition(
pixel: HealpixPixel,
*,
n_src: int,
subsample_src: bool,
lc_col: str,
id_col: str,
hash_range: tuple[int, int] | None,
Expand Down Expand Up @@ -77,6 +84,7 @@ def _process_partition(
columns=columns,
udf=_process_lc,
n_src=n_src,
subsample_src=subsample_src,
lc_col=lc_col,
length_col=length_col,
rng=rng,
Expand All @@ -92,6 +100,7 @@ def lsdb_nested_series_data_generator(
id_col: str = "id",
client: dask.distributed.Client | None,
n_src: int,
subsample_src: bool = True,
partitions_per_chunk: int | None,
hash_range: tuple[int, int] | None = None,
loop: bool = False,
Expand All @@ -101,8 +110,10 @@ def lsdb_nested_series_data_generator(

The data is pre-fetched in the background, 'n_workers' partitions
at a time (derived from the `client` object).
It filters out light curves with less than `n_src` observations,
and selects `n_src` random observations per light curve.
Filters out light curves with fewer than `n_src` observations.
If `subsample_src` is ``True``, selects exactly `n_src` random observations
per light curve. If ``False``, all observations from qualifying light curves
are included.

Parameters
----------
Expand All @@ -118,7 +129,12 @@ def lsdb_nested_series_data_generator(
value. If Dask client is given, the data would be fetched on the
background.
n_src : int
Number of random observations per light curve.
Minimum number of observations required per light curve. Also the
subsample target when `subsample_src` is ``True``.
subsample_src : bool, optional
If ``True`` (default), randomly subsample exactly `n_src` observations
per light curve. If ``False``, include all observations from qualifying
light curves.
partitions_per_chunk : int
Number of `catalog` partitions to load into memory simultaneously.
This changes the randomness.
Expand Down Expand Up @@ -151,6 +167,7 @@ def lsdb_nested_series_data_generator(
_process_partition,
include_pixel=True,
n_src=n_src,
subsample_src=subsample_src,
lc_col=lc_col,
id_col=id_col,
hash_range=hash_range,
Expand Down Expand Up @@ -205,7 +222,8 @@ class LSDBIterableDataset(IterableDataset):
Number of batches to yield. If `splits` is used, it will be the size
of the first subset.
n_src : int
Number of random observations per light curve.
Number of random observations per light curve. Light curves with fewer
than `n_src` observations are filtered out.
partitions_per_chunk : int or None
Number of `catalog` partitions at a time; if None, it is derived
from the number of dask workers associated with `Client` (one if
Expand Down
2 changes: 2 additions & 0 deletions src/uncle_val/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .dp1_constant_magerr import run_dp1_constant_magerr
from .dp1_feature_importance import run_dp1_feature_importance
from .dp1_linear_flux_err import run_dp1_linear_flux_err
from .dp1_mlp import run_dp1_mlp
from .plotting import make_plots

__all__ = (
"make_plots",
"run_dp1_constant_magerr",
"run_dp1_feature_importance",
"run_dp1_linear_flux_err",
"run_dp1_mlp",
)
Loading
Loading