Skip to content

Commit 605dc05

Browse files
committed
Extend load to optionally limit to current parquet files
Why these changes are being introduced: With the creation of TIMDEXRunManager we now have the ability to identify the parquet files associated with current ETL runs, for all sources or for a single source. This can be used to limit a TIMDEXDataset on load to reading only those parquet files. How this addresses that need: * Updates TIMDEXDataset.load() with a new 'current_records' flag that, if True, uses TIMDEXRunManager to get a list of parquet files with which to update the dataset paths. Side effects of this change: * None without explicit use. Eventually, this could be utilized by contexts where only parquet files associated with current ETL runs are needed. Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/TIMX-494
1 parent bd3b937 commit 605dc05

File tree

4 files changed

+81
-18
lines changed

4 files changed

+81
-18
lines changed

tests/test_dataset.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
def test_dataset_init_success(location, expected_file_system, expected_source):
2525
timdex_dataset = TIMDEXDataset(location=location)
2626
assert isinstance(timdex_dataset.filesystem, expected_file_system)
27-
assert timdex_dataset.source == expected_source
27+
assert timdex_dataset.paths == expected_source
2828

2929

3030
def test_dataset_init_env_vars_set_config(monkeypatch, local_dataset_location):
@@ -79,8 +79,7 @@ def test_dataset_load_s3_sets_filesystem_and_dataset_success(
7979
timdex_dataset = TIMDEXDataset(location="s3://bucket/path/to/dataset")
8080
result = timdex_dataset.load()
8181

82-
mock_get_s3_fs.assert_called_once()
83-
mock_pyarrow_ds.assert_called_once_with(
82+
mock_pyarrow_ds.assert_called_with(
8483
"bucket/path/to/dataset",
8584
schema=timdex_dataset.schema,
8685
format="parquet",
@@ -137,6 +136,22 @@ def test_dataset_load_with_multi_nonpartition_filters_success(fixed_local_datase
137136
assert fixed_local_dataset.row_count == 1
138137

139138

139+
def test_dataset_load_current_records_all_sources_success(dataset_with_runs_location):
140+
timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
141+
timdex_dataset.load(current_records=True)
142+
143+
# 14 total parquet files, only 12 related to current runs
144+
assert len(timdex_dataset.dataset.files) == 12
145+
146+
147+
def test_dataset_load_current_records_one_source_success(dataset_with_runs_location):
148+
timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
149+
timdex_dataset.load(current_records=True, source="alma")
150+
151+
# 7 total parquet files for source, only 6 related to current runs
152+
assert len(timdex_dataset.dataset.files) == 6
153+
154+
140155
def test_dataset_get_filtered_dataset_with_single_nonpartition_success(
141156
fixed_local_dataset,
142157
):

tests/test_runs.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,27 @@ def test_timdex_run_manager_get_runs_df(timdex_run_manager):
5656
assert runs_df.source.value_counts().to_dict() == {"alma": 7, "dspace": 7}
5757

5858

59+
def test_timdex_run_manager_get_all_current_run_parquet_files_success(
60+
timdex_run_manager,
61+
):
62+
ordered_parquet_files = timdex_run_manager.get_current_parquet_files()
63+
64+
# assert 12 parquet files, despite being 14 total for ALL sources
65+
# this represents the last full run and all daily since
66+
assert len(ordered_parquet_files) == 12
67+
68+
# assert sorted reverse chronologically
69+
assert "year=2025/month=01/day=01" in ordered_parquet_files[-1]
70+
71+
5972
def test_timdex_run_manager_get_source_current_run_parquet_files_success(
6073
timdex_run_manager,
6174
):
6275
ordered_parquet_files = timdex_run_manager.get_current_source_parquet_files("alma")
6376

64-
# assert 6 parquet files, despite being 8 total for alma
77+
# assert 6 parquet files, despite being 8 total for 'alma' source
6578
# this represents the last full run and all daily since
66-
assert len(ordered_parquet_files)
79+
assert len(ordered_parquet_files) == 6
6780

6881
# assert sorted reverse chronologically
6982
assert "year=2025/month=01/day=05" in ordered_parquet_files[0]

timdex_dataset_api/dataset.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from timdex_dataset_api.config import configure_logger
2222
from timdex_dataset_api.exceptions import DatasetNotLoadedError
23+
from timdex_dataset_api.run import TIMDEXRunManager
2324

2425
if TYPE_CHECKING:
2526
from timdex_dataset_api.record import DatasetRecord # pragma: nocover
@@ -114,7 +115,7 @@ def __init__(
114115
self.location = location
115116
self.config = config or TIMDEXDatasetConfig()
116117

117-
self.filesystem, self.source = self.parse_location(self.location)
118+
self.filesystem, self.paths = self.parse_location(self.location)
118119
self.dataset: ds.Dataset = None # type: ignore[assignment]
119120
self.schema = TIMDEX_DATASET_SCHEMA
120121
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
@@ -129,6 +130,8 @@ def row_count(self) -> int:
129130

130131
def load(
131132
self,
133+
*,
134+
current_records: bool = False,
132135
**filters: Unpack[DatasetFilters],
133136
) -> None:
134137
"""Lazy load a pyarrow.dataset.Dataset and set to self.dataset.
@@ -152,14 +155,23 @@ def load(
152155
"""
153156
start_time = time.perf_counter()
154157

155-
# load dataset
156-
self.dataset = ds.dataset(
157-
self.source,
158-
schema=self.schema,
159-
format="parquet",
160-
partitioning="hive",
161-
filesystem=self.filesystem,
162-
)
158+
# reset paths from original location before load
159+
_, self.paths = self.parse_location(self.location)
160+
161+
# perform initial load of full dataset
162+
self._load_pyarrow_dataset()
163+
164+
# if current_records flag set, limit to parquet files associated with current runs
165+
if current_records:
166+
timdex_run_manager = TIMDEXRunManager(timdex_dataset=self)
167+
168+
# if filters.source is set, further limit to only this source
169+
source = filters.get("source")
170+
if source:
171+
self.paths = timdex_run_manager.get_current_source_parquet_files(source)
172+
else:
173+
self.paths = timdex_run_manager.get_current_parquet_files()
174+
self._load_pyarrow_dataset()
163175

164176
# filter dataset
165177
self.dataset = self._get_filtered_dataset(**filters)
@@ -169,6 +181,16 @@ def load(
169181
f"{round(time.perf_counter()-start_time, 2)}s"
170182
)
171183

184+
def _load_pyarrow_dataset(self) -> None:
185+
"""Load the pyarrow dataset per local filesystem and paths attributes."""
186+
self.dataset = ds.dataset(
187+
self.paths,
188+
schema=self.schema,
189+
format="parquet",
190+
partitioning="hive",
191+
filesystem=self.filesystem,
192+
)
193+
172194
def _get_filtered_dataset(
173195
self,
174196
**filters: Unpack[DatasetFilters],
@@ -345,7 +367,8 @@ def write(
345367
start_time = time.perf_counter()
346368
self._written_files = []
347369

348-
if isinstance(self.source, list):
370+
dataset_filesystem, dataset_path = self.parse_location(self.location)
371+
if isinstance(dataset_path, list):
349372
raise TypeError(
350373
"Dataset location must be the root of a single dataset for writing"
351374
)
@@ -354,10 +377,10 @@ def write(
354377

355378
ds.write_dataset(
356379
record_batches_iter,
357-
base_dir=self.source,
380+
base_dir=dataset_path,
358381
basename_template="%s-{i}.parquet" % (str(uuid.uuid4())), # noqa: UP031
359382
existing_data_behavior="overwrite_or_ignore",
360-
filesystem=self.filesystem,
383+
filesystem=dataset_filesystem,
361384
file_visitor=lambda written_file: self._written_files.append(written_file), # type: ignore[arg-type]
362385
format="parquet",
363386
max_open_files=500,

timdex_dataset_api/run.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,17 @@ def get_current_source_parquet_files(self, source: str) -> list[str]:
115115

116116
return ordered_parquet_files
117117

118+
def get_current_parquet_files(self) -> list[str]:
119+
"""Get reverse chronological list of current parquet files for ALL sources."""
120+
runs_df = self.get_runs_metadata() # run metadata is cached for future calls
121+
sources = list(runs_df.source.unique())
122+
123+
source_parquet_files = []
124+
for source in sources:
125+
source_parquet_files.extend(self.get_current_source_parquet_files(source))
126+
127+
return source_parquet_files
128+
118129
def _get_parquet_files_run_metadata(self, max_workers: int = 250) -> pd.DataFrame:
119130
"""Retrieve run metadata from parquet file(s) in dataset.
120131
@@ -166,8 +177,9 @@ def _parse_run_metadata_from_parquet_file(self, parquet_filepath: str) -> dict:
166177
"""
167178
parquet_file = pq.ParquetFile(
168179
parquet_filepath,
169-
filesystem=self.timdex_dataset.filesystem, # type: ignore[union-attr]
180+
filesystem=self.timdex_dataset.filesystem,
170181
)
182+
171183
file_meta = parquet_file.metadata.to_dict()
172184
num_rows = file_meta["num_rows"]
173185
columns_meta = file_meta["row_groups"][0]["columns"] # type: ignore[typeddict-item]

0 commit comments

Comments
 (0)