HATS query parallelization methods #110

Draft · wants to merge 2 commits into main

tutorials/parquet-catalog-demos/euclid-hats-query-methods.md (175 additions, 0 deletions)

---
jupytext:
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---

# Euclid Q1 HATS query - parallelization methods

This notebook demonstrates a couple of ways that a query for a quality Euclid Q1 redshift sample can be parallelized: executing the pyarrow filter and load on dask workers, and using lsdb. A plain pyarrow load is included for comparison.

+++

## Setup

```{code-cell}
# Uncomment the next line to install dependencies if needed.
# %pip install 'lsdb>=0.5.2' 'numpy>=2.0' pyarrow s3fs
```

```{code-cell}
import os
import sys

import dask
import dask.distributed
import lsdb
import pandas as pd
import pyarrow.compute as pc
import pyarrow.dataset
import pyarrow.fs
```

```{code-cell}
# `top` is a local helper used in these demos to tag phases of the run and record
# `top` output for the resource-usage plot at the end of the notebook.
import top

RUN_ID = "euclid-hats-query-methods"
top.tag(run_id=RUN_ID, time="start-run")
```

```{code-cell}
s3_bucket = "nasa-irsa-euclid-q1"
euclid_prefix = "contributed/q1/merged_objects/hats"

euclid_hats_collection_uri = f"s3://{s3_bucket}/{euclid_prefix}" # for lsdb
euclid_parquet_metadata_path = f"{s3_bucket}/{euclid_prefix}/euclid_q1_merged_objects-hats/dataset/_metadata" # for pyarrow

max_magnitude = 24.5  # I-band AB magnitude limit for the sample
min_flux = 10 ** ((max_magnitude - 23.9) / -2.5)  # ~0.58, assuming fluxes are in microjanskys (AB zero point 23.9)
```

```{code-cell}
# Columns we actually want to load.
OBJECT_ID = "object_id"
PHZ_Z = "phz_phz_median"
columns = [OBJECT_ID, PHZ_Z]
```

## Load with pyarrow

```{code-cell}
%%time
top.tag(run_id=RUN_ID, time="pyarrow")

# Construct filter for quality PHZ redshifts.
phz_filter = (
(pc.field("mer_vis_det") == 1) # No NIR-only objects.
& (pc.field("mer_flux_detection_total") > min_flux) # I < 24.5
& (pc.divide(pc.field("mer_flux_detection_total"), pc.field("mer_fluxerr_detection_total")) > 5) # I band S/N > 5
& ~pc.field("phz_phz_classification").isin([1, 3, 5, 7]) # Exclude objects classified as star.
& (pc.field("mer_spurious_flag") == 0) # MER quality
)

# Load.
dataset = pyarrow.dataset.parquet_dataset(
    euclid_parquet_metadata_path, partitioning="hive", filesystem=pyarrow.fs.S3FileSystem()
)
pa_df = dataset.to_table(columns=columns, filter=phz_filter).to_pandas()
# 1 - 1.5 min

print(f"Pyarrow loaded {sys.getsizeof(pa_df) / 1024**3:.3f} GiB")
```

```{code-cell}
pa_df = pa_df.sort_values(OBJECT_ID, ignore_index=True)
pa_df
```

## Load with pyarrow + dask

The above query probably doesn't need to be parallelized, but we will want to parallelize similar queries in the future.
One option is to have dask workers execute the pyarrow filter and load.
This query has no spatial component and needs to look at every file, so the basic implementation below simply distributes one file (parquet fragment) per dask task.
(I wonder if this could be improved by batching the files to reduce per-task overhead? A sketch of that idea follows the basic implementation.)

```{code-cell}
%%time
top.tag(run_id=RUN_ID, time="pyarrow + dask")


@dask.delayed
def load_fragment(frag):
table = frag.to_table(filter=phz_filter, columns=columns)
return table.to_pandas()


client = dask.distributed.Client(n_workers=os.cpu_count(), threads_per_worker=2, memory_limit=None)
delayed_dfs = [load_fragment(frag) for frag in dataset.get_fragments()]
padask_df = pd.concat(dask.compute(*delayed_dfs), ignore_index=True)
client.close()
# 40 sec - 1 min (4, 8, and 16 workers)

print(f"Pyarrow + dask loaded {sys.getsizeof(padask_df) / 1024**3:.3f} GiB")
# Ignoring the warnings about the s3 connection for now.
```
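
As one possible answer to the batching question above, here is an illustrative sketch (not executed or benchmarked in this notebook) that reuses `dataset`, `phz_filter`, and `columns` from the cells above and groups the parquet fragments so each dask task reads several files. The batch size of 8 is an arbitrary assumption to tune.

```python
# Sketch only: batch the parquet fragments so each dask task reads several files,
# which should reduce scheduler and task-startup overhead compared to one file per task.
import dask
import pandas as pd


@dask.delayed
def load_fragment_batch(frags):
    # Read and filter every fragment in the batch, then combine into one pandas DataFrame.
    tables = [frag.to_table(filter=phz_filter, columns=columns) for frag in frags]
    return pd.concat([table.to_pandas() for table in tables], ignore_index=True)


batch_size = 8  # arbitrary; tune against file sizes and worker memory
fragments = list(dataset.get_fragments())
batches = [fragments[i : i + batch_size] for i in range(0, len(fragments), batch_size)]
batched_df = pd.concat(dask.compute(*[load_fragment_batch(b) for b in batches]), ignore_index=True)
```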

```{code-cell}
padask_df = padask_df.sort_values(OBJECT_ID, ignore_index=True)
padask_df
```

```{code-cell}
# Check for equality.
pa_df.equals(padask_df)
```

## Load with lsdb

```{code-cell}
%%time
top.tag(run_id=RUN_ID, time="lsdb")

# Construct the query equivalent of phz_filter.
query = (
"mer_vis_det == 1"
f" & mer_flux_detection_total > {min_flux}"
" & mer_flux_detection_total / mer_fluxerr_detection_total > 5"
" & phz_phz_classification not in [1,3,5,7]"
" & mer_spurious_flag == 0"
)

# We don't need these columns in the result, but they must be loaded in order to use them in the filter
# (a sketch of dropping them after compute() follows this cell).
extra_columns = [
    "mer_vis_det",
    "mer_flux_detection_total",
    "mer_fluxerr_detection_total",
    "phz_phz_classification",
    "mer_spurious_flag",
]

# Load.
client = dask.distributed.Client(n_workers=os.cpu_count(), threads_per_worker=2, memory_limit=None)
lsdb_catalog = lsdb.read_hats(euclid_hats_collection_uri, columns=columns + extra_columns)
lsdb_df = lsdb_catalog.query(query).compute()
client.close()
# 4.5 - 5.5 min (16 workers)
# 8 - 9 min (8 workers)
# 21 min (4 workers)

print(f"lsdb loaded {sys.getsizeof(lsdb_df) / 1024**3:.3f} GiB")
```
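
If the extra filter columns aren't needed downstream, one small follow-up (shown only as a sketch here, not executed or timed) is to keep just the science columns once the filter has been applied:

```python
# Sketch only: drop the filter-only columns after computing, keeping object_id and the redshift.
lsdb_df = lsdb_df[columns]
print(f"lsdb (science columns only) holds {sys.getsizeof(lsdb_df) / 1024**3:.3f} GiB")
```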

```{code-cell}
lsdb_df = lsdb_df.sort_values(OBJECT_ID, ignore_index=True)
lsdb_df
```

```{code-cell}
# Check for equality; the cast is needed because lsdb loads this column
# with a different float dtype than the pyarrow result.
print(lsdb_df[PHZ_Z].astype("float32").equals(pa_df[PHZ_Z]))

top.tag(run_id=RUN_ID, time="end-run")


tl = top.load_top_output(run_id=RUN_ID, named_pids_only=False)
fig = tl.plot_overview()
fig.savefig(tl.base_dir / "top.png")
```