Merged
1 change: 1 addition & 0 deletions Pipfile
@@ -25,6 +25,7 @@ pytest = "*"
ruff = "*"
setuptools = "*"
pip-audit = "*"
pytest-freezegun = "*"

[requires]
python_version = "3.12"
18 changes: 17 additions & 1 deletion Pipfile.lock

Some generated files are not rendered by default.

54 changes: 54 additions & 0 deletions tests/conftest.py
@@ -153,3 +153,57 @@ def dataset_with_runs_location(tmp_path) -> str:
@pytest.fixture
def local_dataset_with_runs(dataset_with_runs_location) -> TIMDEXDataset:
return TIMDEXDataset(dataset_with_runs_location)


@pytest.fixture
def dataset_with_same_day_runs(tmp_path) -> TIMDEXDataset:
"""Dataset fixture where a single source had multiple runs on the same day.

After these runs, we'd expect 70 records in OpenSearch:
- the most recent full run, "run-2", established a 75-record baseline
- runs "run-3" and "run-4" only modified records; no record count change
- run "run-5" deleted 5 records

If the ordering of full runs 1 & 2 is not handled correctly, we'd see an incorrect
baseline of 100 records.

If the ordering of daily runs 4 & 5 is not handled correctly, we'd see 75 records
because the deletes would be applied before run-4's index actions re-created those
records.
"""
location = str(tmp_path / "dataset_with_same_day_runs")
os.mkdir(location)

timdex_dataset = TIMDEXDataset(location)

run_params = []

# Simulate two "full" runs where "run-2" should establish the baseline.
# Simulate daily runs, some occurring on the same day, where the deletes from
# "run-5" should be represented.
run_params.extend(
[
(100, "alma", "2025-01-01", "full", "index", "run-1"),
(75, "alma", "2025-01-01", "full", "index", "run-2"),
(10, "alma", "2025-01-01", "daily", "index", "run-3"),
(20, "alma", "2025-01-02", "daily", "index", "run-4"),
(5, "alma", "2025-01-02", "daily", "delete", "run-5"),
]
)

for params in run_params:
num_records, source, run_date, run_type, action, run_id = params
records = generate_sample_records(
num_records,
timdex_record_id_prefix=source,
source=source,
run_date=run_date,
run_type=run_type,
action=action,
run_id=run_id,
)
timdex_dataset.write(records)

# reload after writes
timdex_dataset.load()

return timdex_dataset
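
For context on the fixture above: a minimal sketch, not the library's implementation, of how ordering run metadata by run_timestamp rather than run_date resolves the same-day ambiguity these runs create. The run_timestamp values are hypothetical and stand in for the order the runs were written.

import pandas as pd

# Hypothetical run metadata mirroring the fixture; run_timestamp values are
# illustrative and reflect the order the runs were written within each run_date.
runs = pd.DataFrame(
    [
        ("run-1", "full", "2025-01-01", "2025-01-01T10:00:00Z"),
        ("run-2", "full", "2025-01-01", "2025-01-01T11:00:00Z"),
        ("run-3", "daily", "2025-01-01", "2025-01-01T12:00:00Z"),
        ("run-4", "daily", "2025-01-02", "2025-01-02T09:00:00Z"),
        ("run-5", "daily", "2025-01-02", "2025-01-02T10:00:00Z"),
    ],
    columns=["run_id", "run_type", "run_date", "run_timestamp"],
)

# run_date alone cannot order run-1 vs run-2 (or run-4 vs run-5); run_timestamp can.
latest_full = runs[runs.run_type == "full"].sort_values("run_timestamp").iloc[-1]
assert latest_full.run_id == "run-2"

latest_daily = runs[runs.run_type == "daily"].sort_values("run_timestamp").iloc[-1]
assert latest_daily.run_id == "run-5"
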
70 changes: 68 additions & 2 deletions tests/test_dataset.py
@@ -1,13 +1,14 @@
# ruff: noqa: D205, S105, S106, SLF001, PD901, PLR2004
# ruff: noqa: D205, D209, S105, S106, SLF001, PD901, PLR2004

import os
from datetime import date
from datetime import UTC, date, datetime
from unittest.mock import MagicMock, patch

import pyarrow as pa
import pytest
from pyarrow import fs

from tests.utils import generate_sample_records
from timdex_dataset_api.dataset import (
DatasetNotLoadedError,
TIMDEXDataset,
@@ -463,3 +464,68 @@ def test_dataset_current_records_index_filtering_accurate_records_yielded(
"alma:23",
"alma:24",
]


@pytest.mark.freeze_time("2025-05-22 01:23:45.567890")
def test_dataset_write_includes_minted_run_timestamp(tmp_path):
# create dataset
location = str(tmp_path / "one_run_at_frozen_time")
os.mkdir(location)
timdex_dataset = TIMDEXDataset(location)

run_id = "abc123"

# perform a single ETL run that should pick up the frozen time for run_timestamp
records = generate_sample_records(
10,
timdex_record_id_prefix="alma",
source="alma",
run_date="2025-05-22",
run_type="full",
action="index",
run_id=run_id,
)
timdex_dataset.write(records)
timdex_dataset.load()

# assert TIMDEXDataset.write() applies current time as run_timestamp
run_row_dict = next(timdex_dataset.read_dicts_iter())
assert "run_timestamp" in run_row_dict
assert run_row_dict["run_timestamp"] == datetime(
2025,
5,
22,
1,
23,
45,
567890,
tzinfo=UTC,
)

# assert the same run_timestamp is applied to all rows in the run
df = timdex_dataset.read_dataframe(run_id=run_id)
assert len(list(df.run_timestamp.unique())) == 1


def test_dataset_load_current_records_gets_correct_same_day_full_run(
dataset_with_same_day_runs,
):
"""Two full runs were performed on the same day, but 'run-2' was performed most
recently. current_records=True should discover the more recent of the two 'run-2',
not 'run-1'."""
dataset_with_same_day_runs.load(current_records=True, run_type="full")
df = dataset_with_same_day_runs.read_dataframe()

assert list(df.run_id.unique()) == ["run-2"]


def test_dataset_load_current_records_gets_correct_same_day_daily_runs_ordering(
dataset_with_same_day_runs,
):
"""Two runs were performed on 2025-01-02, but the most recent records should be from
run 'run-5' which are action='delete', not 'run-4' with action='index'."""
dataset_with_same_day_runs.load(current_records=True, run_type="daily")
first_record = next(dataset_with_same_day_runs.read_dicts_iter())

assert first_record["run_id"] == "run-5"
assert first_record["action"] == "delete"
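
The frozen-time test above depends on pytest-freezegun (added to the Pipfile in this change) patching datetime, so that datetime.now(UTC) inside TIMDEXDataset.write() returns the frozen instant. A minimal sketch of that behavior, assuming pytest-freezegun is installed:

from datetime import UTC, datetime

import pytest


@pytest.mark.freeze_time("2025-05-22 01:23:45.567890")
def test_now_returns_frozen_instant():
    # freezegun treats the naive frozen string as UTC, so now(UTC) matches it exactly
    assert datetime.now(UTC) == datetime(2025, 5, 22, 1, 23, 45, 567890, tzinfo=UTC)
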
29 changes: 20 additions & 9 deletions timdex_dataset_api/dataset.py
@@ -41,6 +41,7 @@
pa.field("year", pa.string()),
pa.field("month", pa.string()),
pa.field("day", pa.string()),
pa.field("run_timestamp", pa.timestamp("us", tz="UTC")),
)
)

@@ -62,6 +63,7 @@ class DatasetFilters(TypedDict, total=False):
year: str | None
month: str | None
day: str | None
run_timestamp: str | datetime | None


@dataclass
@@ -112,15 +114,19 @@ def __init__(
location (str | list[str]): Local filesystem path or an S3 URI to
a parquet dataset. For partitioned datasets, set to the base directory.
"""
self.location = location
self.config = config or TIMDEXDatasetConfig()
self.location = location

# pyarrow dataset
self.filesystem, self.paths = self.parse_location(self.location)
self.dataset: ds.Dataset = None # type: ignore[assignment]
self.schema = TIMDEX_DATASET_SCHEMA
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS

# writing
self._written_files: list[ds.WrittenFile] = None # type: ignore[assignment]

# reading
self._current_records: bool = False
self._current_records_dataset: ds.Dataset = None # type: ignore[assignment]

@@ -405,26 +411,31 @@ def write(
return self._written_files # type: ignore[return-value]

def create_record_batches(
self,
records_iter: Iterator["DatasetRecord"],
self, records_iter: Iterator["DatasetRecord"]
) -> Iterator[pa.RecordBatch]:
"""Yield pyarrow.RecordBatches for writing.

This method expects an iterator of DatasetRecord instances.

Each DatasetRecord is validated and serialized to a dictionary before added to a
pyarrow.RecordBatch for writing.
Each DatasetRecord is serialized to a dictionary, any column data shared by all
rows (e.g. run_timestamp) is merged into that dictionary, and the result is added
to a pyarrow.RecordBatch for writing.

Args:
- records_iter: Iterator of DatasetRecord instances
"""
run_timestamp = datetime.now(UTC)
for i, record_batch in enumerate(
itertools.batched(records_iter, self.config.write_batch_size)
):
batch = pa.RecordBatch.from_pylist(
[record.to_dict() for record in record_batch]
)
logger.debug(f"Yielding batch {i+1} for dataset writing.")
record_dicts = [
{
**record.to_dict(),
"run_timestamp": run_timestamp,
}
for record in record_batch
]
batch = pa.RecordBatch.from_pylist(record_dicts)
logger.debug(f"Yielding batch {i + 1} for dataset writing.")
yield batch

def log_write_statistics(self, start_time: float) -> None:
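
create_record_batches relies on itertools.batched (available in Python 3.12, which the Pipfile pins) to chunk the incoming record iterator; the final batch may be smaller than write_batch_size. A minimal sketch of that chunking pattern, with a hypothetical batch size:

import itertools

# Hypothetical batch size; the real value comes from TIMDEXDatasetConfig.write_batch_size.
write_batch_size = 3

records = iter(range(8))
for i, batch in enumerate(itertools.batched(records, write_batch_size)):
    print(f"batch {i + 1}: {batch}")
# batch 1: (0, 1, 2)
# batch 2: (3, 4, 5)
# batch 3: (6, 7)
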
7 changes: 5 additions & 2 deletions timdex_dataset_api/run.py
@@ -52,6 +52,7 @@ def get_runs_metadata(self, *, refresh: bool = False) -> pd.DataFrame:
"source": "first",
"run_date": "first",
"run_type": "first",
"run_timestamp": "first",
"num_rows": "sum",
"filename": list,
}
@@ -65,9 +66,9 @@ def get_runs_metadata(self, *, refresh: bool = False) -> pd.DataFrame:
lambda x: len(x)
)

# sort by run date and source
# sort by run_timestamp (more granularity than run_date) and source
grouped_runs_df = grouped_runs_df.sort_values(
["run_date", "source"], ascending=False
["run_timestamp", "source"], ascending=False
)

# cache the result
@@ -185,12 +186,14 @@ def _parse_run_metadata_from_parquet_file(self, parquet_filepath: str) -> dict:
run_date = columns_meta[4]["statistics"]["max"]
run_type = columns_meta[5]["statistics"]["max"]
run_id = columns_meta[7]["statistics"]["max"]
run_timestamp = columns_meta[9]["statistics"]["max"]

return {
"source": source,
"run_date": run_date,
"run_type": run_type,
"run_id": run_id,
"run_timestamp": run_timestamp,
"num_rows": num_rows,
"filename": parquet_filepath,
}
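
_parse_run_metadata_from_parquet_file reads run-level values from each parquet file's column statistics by positional index, with run_timestamp now at index 9. A minimal sketch of reading those statistics directly with pyarrow; the filename is hypothetical:

import pyarrow.parquet as pq

# Hypothetical parquet file written by TIMDEXDataset.write()
metadata = pq.read_metadata("part-0000.parquet")

num_rows = metadata.num_rows
row_group = metadata.row_group(0)

# Column order follows the dataset schema; index 9 holds run_timestamp.
run_timestamp_max = row_group.column(9).statistics.max
print(num_rows, run_timestamp_max)
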