Skip to content

Commit a8cc8d4

Browse files
authored
Updates for DP2 (#71)
* Updates for DP2
* Fast fix
1 parent 76af4b7 commit a8cc8d4

File tree

4 files changed

+20
-7
lines changed

4 files changed

+20
-7
lines changed

src/uncle_val/datasets/dp1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,10 @@ def _read_ccd_visit_table(path, columns):
130130
if columns is None or "seeing" in columns:
131131
seeing = pa.array(df["seeing"])
132132
replace_value = np.nanmean(seeing)
133-
df["seeing"] = pa.compute.replace_with_mask(seeing, pa.compute.is_nan(seeing), replace_value)
133+
nan_seeing_mask = pa.compute.is_nan(seeing)
134+
if isinstance(nan_seeing_mask, pa.ChunkedArray):
135+
nan_seeing_mask = nan_seeing_mask.combine_chunks()
136+
df["seeing"] = pa.compute.replace_with_mask(seeing, nan_seeing_mask, replace_value)
134137
if columns is None or "detectorId" in columns:
135138
detector_cols = _polar_encode_detector(df["detectorId"])
136139
df = df.assign(**detector_cols)

src/uncle_val/datasets/materialized.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ class MaterializedDataLoaderContext:
4545
It will be deleted on the context exit.
4646
"""
4747

48-
def __init__(self, input_dataset: LSDBIterableDataset, tmp_dir: Path | str):
48+
def __init__(self, input_dataset: LSDBIterableDataset, tmp_dir: Path | str, cleanup: bool):
4949
self.input_dataset = input_dataset
5050
self.tmp_dir = Path(tmp_dir)
51+
self.cleanup = cleanup
5152

5253
def _serialize_data(self):
5354
n_chunks = 0
@@ -74,4 +75,5 @@ def __enter__(self) -> DataLoader:
7475
)
7576

7677
def __exit__(self, exc_type, exc_val, exc_tb):
77-
shutil.rmtree(self.tmp_dir)
78+
if self.cleanup:
79+
shutil.rmtree(self.tmp_dir)

src/uncle_val/pipelines/plotting.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,11 @@ def _get_hists(
212212
device=torch.device(device),
213213
)
214214

215-
with Client(n_workers=n_workers, threads_per_worker=1, memory_limit="8GB") as client:
216-
print(f"Dask Dashboard Link: {client.dashboard_link}")
215+
with Client(n_workers=n_workers, threads_per_worker=1, memory_limit="64GB") as client:
216+
try:
217+
print(f"Dask Dashboard Link: {client.dashboard_link}")
218+
except KeyError as e:
219+
print(f"Cannot get Dask Dashboard Link: {e}")
217220
hists_df = hists.compute()
218221

219222
return hists_df

src/uncle_val/pipelines/training_loop.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ def training_loop(
129129
model = model.to(device)
130130

131131
with Client(n_workers=n_workers, memory_limit="8GB", threads_per_worker=1) as client:
132-
print(f"Dask Dashboard Link: {client.dashboard_link}")
132+
try:
133+
print(f"Dask Dashboard Link: {client.dashboard_link}")
134+
except KeyError as e:
135+
print(f"Cannot get Dask Dashboard Link: {e}")
133136

134137
validation_dataset_lsdb = LSDBIterableDataset(
135138
catalog=catalog,
@@ -144,7 +147,9 @@ def training_loop(
144147
device=device,
145148
)
146149

147-
with MaterializedDataLoaderContext(validation_dataset_lsdb, tmp_validation_dir) as val_dataloader:
150+
with MaterializedDataLoaderContext(
151+
validation_dataset_lsdb, tmp_validation_dir, cleanup=False
152+
) as val_dataloader:
148153
val_stats_future: Future | None = None
149154
mean_val_loss_i = 0
150155

0 commit comments

Comments (0)