Skip to content

Commit 93e22ec

Browse files
authored
Cloud ingestion optimizations (#738)
* Working tests with fake server
* Fix possible import issues
* Cleanup
* Strip out local gcs fake server and revert some regressions
* Put max workers config back in worker processes as setting globally did not appear to be fully honored
* Reordering to match original code
* pre-commit and cleanup warnings
* Remove unnecessary global `SegyFile`
1 parent e10c50c commit 93e22ec

File tree

2 files changed: +53 additions, −43 deletions

src/mdio/segy/_workers.py

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,15 @@
88
import numpy as np
99
from segy.arrays import HeaderArray
1010

11-
from mdio.api.io import _normalize_storage_options
1211
from mdio.core.config import MDIOSettings
1312
from mdio.segy._raw_trace_wrapper import SegyFileRawTraceWrapper
1413
from mdio.segy.file import SegyFileArguments
1514
from mdio.segy.file import SegyFileWrapper
1615

1716
if TYPE_CHECKING:
18-
from upath import UPath
17+
from segy import SegyFile
1918
from zarr import Array as zarr_Array
2019

21-
from zarr import open_group as zarr_open_group
2220
from zarr.core.config import config as zarr_config
2321

2422
from mdio.builder.schemas.v1.stats import CenteredBinHistogram
@@ -71,26 +69,30 @@ def header_scan_worker(
7169

7270

7371
def trace_worker( # noqa: PLR0913
74-
segy_file_kwargs: SegyFileArguments,
75-
output_path: UPath,
76-
data_variable_name: str,
72+
segy_file: SegyFile,
73+
data_array: zarr_Array,
74+
header_array: zarr_Array | None,
75+
raw_header_array: zarr_Array | None,
7776
region: dict[str, slice],
7877
grid_map: zarr_Array,
7978
) -> SummaryStatistics | None:
8079
"""Writes a subset of traces from a region of the dataset of Zarr file.
8180
8281
Args:
83-
segy_file_kwargs: Arguments to open SegyFile instance.
84-
output_path: Universal Path for the output Zarr dataset
85-
(e.g. local file path or cloud storage URI) the location
86-
also includes storage options for cloud storage.
87-
data_variable_name: Name of the data variable to write.
82+
segy_file: The opened SEG-Y file.
83+
data_array: Zarr array for writing trace data.
84+
header_array: Zarr array for writing trace headers (or None if not needed).
85+
raw_header_array: Zarr array for writing raw headers (or None if not needed).
8886
region: Region of the dataset to write to.
8987
grid_map: Zarr array mapping live traces to their positions in the dataset.
9088
9189
Returns:
9290
SummaryStatistics object containing statistics about the written traces.
9391
"""
92+
# Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__CPU_COUNT` environment variable.
93+
# The Zarr 3 engine utilizes multiple threads. This can lead to resource contention and unpredictable memory usage.
94+
zarr_config.set({"threading.max_workers": 1})
95+
9496
region_slices = tuple(region.values())
9597
local_grid_map = grid_map[region_slices[:-1]] # minus last (vertical) axis
9698

@@ -100,26 +102,8 @@ def trace_worker( # noqa: PLR0913
100102
if not not_null.any():
101103
return None
102104

103-
# Open the SEG-Y file in this process since the open file handles cannot be shared across processes.
104-
segy_file = SegyFileWrapper(**segy_file_kwargs)
105-
106-
# Setting the zarr config to 1 thread to ensure we honor the `MDIO__IMPORT__MAX_WORKERS` environment variable.
107-
# The Zarr 3 engine utilizes multiple threads. This can lead to resource contention and unpredictable memory usage.
108-
zarr_config.set({"threading.max_workers": 1})
109-
110105
live_trace_indexes = local_grid_map[not_null].tolist()
111106

112-
# Open the zarr group to write directly
113-
storage_options = _normalize_storage_options(output_path)
114-
zarr_group = zarr_open_group(output_path.as_posix(), mode="r+", storage_options=storage_options)
115-
116-
header_key = "headers"
117-
raw_header_key = "raw_headers"
118-
119-
# Check which variables exist in the zarr store
120-
available_arrays = list(zarr_group.array_keys())
121-
122-
# traces = segy_file.trace[live_trace_indexes]
123107
# Raw headers are not intended to remain as a feature of the SEGY ingestion.
124108
# For that reason, we have wrapped the accessors to provide an interface that can be removed
125109
# and not require additional changes to the below code.
@@ -132,24 +116,21 @@ def trace_worker( # noqa: PLR0913
132116
full_shape = tuple(s.stop - s.start for s in region_slices)
133117
header_shape = tuple(s.stop - s.start for s in header_region_slices)
134118

135-
# Write raw headers if they exist
119+
# Write raw headers if array was provided
136120
# Headers only have spatial dimensions (no sample dimension)
137-
if raw_header_key in available_arrays:
138-
raw_header_array = zarr_group[raw_header_key]
121+
if raw_header_array is not None:
139122
tmp_raw_headers = np.full(header_shape, raw_header_array.fill_value)
140123
tmp_raw_headers[not_null] = traces.raw_header
141124
raw_header_array[header_region_slices] = tmp_raw_headers
142125

143-
# Write headers if they exist
126+
# Write headers if array was provided
144127
# Headers only have spatial dimensions (no sample dimension)
145-
if header_key in available_arrays:
146-
header_array = zarr_group[header_key]
128+
if header_array is not None:
147129
tmp_headers = np.full(header_shape, header_array.fill_value)
148130
tmp_headers[not_null] = traces.header
149131
header_array[header_region_slices] = tmp_headers
150132

151133
# Write the data variable
152-
data_array = zarr_group[data_variable_name]
153134
tmp_samples = np.full(full_shape, data_array.fill_value)
154135
tmp_samples[not_null] = traces.sample
155136
data_array[region_slices] = tmp_samples

src/mdio/segy/blocked_io.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import zarr
1313
from dask.array import Array
1414
from dask.array import map_blocks
15+
from segy import SegyFile
1516
from tqdm.auto import tqdm
1617
from zarr import open_group as zarr_open_group
1718

@@ -80,18 +81,48 @@ def to_zarr( # noqa: PLR0913, PLR0915
8081
chunk_iter = ChunkIterator(shape=data.shape, chunks=worker_chunks, dim_names=data.dims)
8182
num_chunks = chunk_iter.num_chunks
8283

84+
zarr_format = zarr.config.get("default_zarr_format")
85+
86+
# Open zarr group once in main process
87+
storage_options = _normalize_storage_options(output_path)
88+
zarr_group = zarr_open_group(
89+
output_path.as_posix(),
90+
mode="r+",
91+
storage_options=storage_options,
92+
use_consolidated=zarr_format == ZarrFormat.V2,
93+
)
94+
95+
# Get array handles from the opened group
96+
data_array = zarr_group[data_variable_name]
97+
header_array = zarr_group.get("headers")
98+
raw_header_array = zarr_group.get("raw_headers")
99+
83100
# For Unix async writes with s3fs/fsspec & multiprocessing, use 'spawn' instead of default
84101
# 'fork' to avoid deadlocks on cloud stores. Slower but necessary. Default on Windows.
85102
num_workers = min(num_chunks, settings.import_cpus)
86103
context = mp.get_context("spawn")
87-
executor = ProcessPoolExecutor(max_workers=num_workers, mp_context=context)
104+
105+
# Use initializer to open segy file once per worker
106+
executor = ProcessPoolExecutor(
107+
max_workers=num_workers,
108+
mp_context=context,
109+
)
110+
111+
segy_file = SegyFile(**segy_file_kwargs)
88112

89113
with executor:
90114
futures = []
91-
common_args = (segy_file_kwargs, output_path, data_variable_name)
92115
for region in chunk_iter:
93-
subset_args = (region, grid_map)
94-
future = executor.submit(trace_worker, *common_args, *subset_args)
116+
# Pass zarr array handles directly to workers
117+
future = executor.submit(
118+
trace_worker,
119+
segy_file,
120+
data_array,
121+
header_array,
122+
raw_header_array,
123+
region,
124+
grid_map,
125+
)
95126
futures.append(future)
96127

97128
iterable = tqdm(
@@ -106,11 +137,9 @@ def to_zarr( # noqa: PLR0913, PLR0915
106137
if result is not None:
107138
_update_stats(final_stats, result)
108139

140+
# Update statistics using the already-open zarr group
109141
# Xarray doesn't directly support incremental attribute updates when appending to an existing Zarr store.
110142
# HACK: We will update the array attribute using zarr's API directly.
111-
# Use the data_variable_name to get the array in the Zarr group and write "statistics" metadata there
112-
storage_options = _normalize_storage_options(output_path)
113-
zarr_group = zarr_open_group(output_path.as_posix(), mode="a", storage_options=storage_options)
114143
attr_json = final_stats.model_dump_json()
115144
zarr_group[data_variable_name].attrs.update({"statsV1": attr_json})
116145

0 commit comments

Comments (0)