Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# ruff: noqa: S105, S106, SLF001, PLR2004
# ruff: noqa: D205, S105, S106, SLF001, PD901, PLR2004

import os
from datetime import date
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -397,3 +398,68 @@ def test_dataset_all_read_methods_get_deduplication(
transformed_records = list(local_dataset_with_runs.read_transformed_records_iter())

assert len(full_df) == len(all_records) == len(transformed_records)


def test_dataset_current_records_no_additional_filtering_accurate_records_yielded(
    local_dataset_with_runs,
):
    """A current-records load with no read filters yields 99 index / 1 delete rows."""
    local_dataset_with_runs.load(current_records=True, source="alma")
    frame = local_dataset_with_runs.read_dataframe()
    action_counts = frame.action.value_counts().to_dict()
    assert action_counts == {"index": 99, "delete": 1}


def test_dataset_current_records_action_filtering_accurate_records_yielded(
    local_dataset_with_runs,
):
    """Filtering on action='index' after a current-records load yields only index rows."""
    local_dataset_with_runs.load(current_records=True, source="alma")
    frame = local_dataset_with_runs.read_dataframe(action="index")
    assert frame.action.value_counts().to_dict() == {"index": 99}


def test_dataset_current_records_index_filtering_accurate_records_yielded(
    local_dataset_with_runs,
):
    """This is a somewhat complex test, but demonstrates that only 'current' records
    are yielded when .load(current_records=True) is applied.

    Given these runs from the fixture:
    [
        ...
        (25, "alma", "2025-01-03", "daily", "index", "run-5"), <---- filtered to
        (10, "alma", "2025-01-04", "daily", "delete", "run-6"), <---- influences current
        ...
    ]

    Though we are filtering to run-5, which has 25 total records to-index, we see only
    15 records yielded. Why? This is because while we have filtered to only yield from
    run-5, run-6 had 10 deletes which made records alma:0|9 no longer "current" in
    run-5. As we yield records reverse chronologically, the deletes from run-6
    (alma:0-alma:9) "influenced" what records we would see as we continue backwards
    in time.
    """
    # with current_records=False, we get all 25 records from run-5
    local_dataset_with_runs.load(current_records=False, source="alma")
    df = local_dataset_with_runs.read_dataframe(run_id="run-5")
    assert len(df) == 25

    # with current_records=True, we only get 15 records from run-5
    # because newer run-6 influenced what records are current for older run-5
    local_dataset_with_runs.load(current_records=True, source="alma")
    df = local_dataset_with_runs.read_dataframe(run_id="run-5")
    assert len(df) == 15
    # alma:0 through alma:9 were deleted by run-6, leaving alma:10 - alma:24
    assert list(df.timdex_record_id) == [f"alma:{i}" for i in range(10, 25)]
8 changes: 5 additions & 3 deletions tests/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,21 @@
@pytest.fixture
def timdex_run_manager(dataset_with_runs_location):
    """Return a TIMDEXRunManager backed by a loaded fixture dataset.

    The TIMDEXDataset must be loaded first: TIMDEXRunManager now takes the
    underlying pyarrow dataset, which only exists after .load().
    """
    timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
    timdex_dataset.load()
    return TIMDEXRunManager(dataset=timdex_dataset.dataset)


def test_timdex_run_manager_init(dataset_with_runs_location):
    """A freshly constructed TIMDEXRunManager starts with an empty metadata cache."""
    timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
    timdex_dataset.load()
    timdex_run_manager = TIMDEXRunManager(dataset=timdex_dataset.dataset)
    assert timdex_run_manager._runs_metadata_cache is None


def test_timdex_run_manager_parse_single_parquet_file_success(timdex_run_manager):
"""Parse run metadata from first parquet file in fixture dataset. We know the details
of this ETL run in advance given the deterministic fixture that generated it."""
parquet_filepath = timdex_run_manager.timdex_dataset.dataset.files[0]
parquet_filepath = timdex_run_manager.dataset.files[0]
run_metadata = timdex_run_manager._parse_run_metadata_from_parquet_file(
parquet_filepath
)
Expand Down
137 changes: 85 additions & 52 deletions timdex_dataset_api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def __init__(
self.schema = TIMDEX_DATASET_SCHEMA
self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
self._written_files: list[ds.WrittenFile] = None # type: ignore[assignment]
self._dedupe_on_read: bool = False

self._current_records: bool = False
self._current_records_dataset: ds.Dataset = None # type: ignore[assignment]

@property
def row_count(self) -> int:
Expand Down Expand Up @@ -153,27 +155,32 @@ def load(
- filters: kwargs typed via DatasetFilters TypedDict
- Filters passed directly in method call, e.g. source="alma",
run_date="2024-12-20", etc., but are typed according to DatasetFilters.
- current_records: bool
- if True, the TIMDEXRunManager will be used to retrieve a list of parquet
files associated with current runs, some internal flags will be set, all
ensuring that only current records are yielded for any read methods
"""
start_time = time.perf_counter()

# reset paths from original location before load
_, self.paths = self.parse_location(self.location)

# perform initial load of full dataset
self._load_pyarrow_dataset()
self.dataset = self._load_pyarrow_dataset()

# if current_records flag set, limit to parquet files associated with current runs
self._dedupe_on_read = current_records
self._current_records = current_records
if current_records:
timdex_run_manager = TIMDEXRunManager(timdex_dataset=self)

# update paths, limiting by source if set
timdex_run_manager = TIMDEXRunManager(dataset=self.dataset)
self.paths = timdex_run_manager.get_current_parquet_files(
source=filters.get("source")
)

# reload pyarrow dataset
self._load_pyarrow_dataset()
# reload pyarrow dataset, filtered now to an explicit list of parquet files
# also save an instance of the dataset before any additional filtering
dataset = self._load_pyarrow_dataset()
self.dataset = dataset
self._current_records_dataset = dataset

# filter dataset
self.dataset = self._get_filtered_dataset(**filters)
Expand All @@ -183,9 +190,9 @@ def load(
f"{round(time.perf_counter()-start_time, 2)}s"
)

def _load_pyarrow_dataset(self) -> None:
def _load_pyarrow_dataset(self) -> ds.Dataset:
"""Load the pyarrow dataset per local filesystem and paths attributes."""
self.dataset = ds.dataset(
return ds.dataset(
self.paths,
schema=self.schema,
format="parquet",
Expand Down Expand Up @@ -449,19 +456,14 @@ def read_batches_iter(
"""Yield pyarrow.RecordBatches from the dataset.

While batch_size will limit the max rows per batch, filtering may result in some
batches have less than this limit.
batches having less than this limit.

If the flag self._current_records is set, this method leans on
self._yield_current_record_deduped_batches() to apply deduplication of records to
ensure only current versions of the record are ever yielded.

Args:
- columns: list[str], list of columns to return from the dataset
- batch_size: int, max number of rows to yield per batch
- batch_read_ahead: int, the number of batches to read ahead in a file. This
might not work for all file formats. Increasing this number will increase
RAM usage but could also improve IO utilization. Pyarrow default is 16,
but this library defaults to 0 to prioritize memory footprint.
- fragment_read_ahead: int, The number of files to read ahead. Increasing this
number will increase RAM usage but could also improve IO utilization.
Pyarrow default is 4, but this library defaults to 0 to prioritize memory
footprint.
- filters: pairs of column:value to filter the dataset
"""
if not self.dataset:
Expand All @@ -477,47 +479,78 @@ def read_batches_iter(
fragment_readahead=self.config.fragment_read_ahead,
)

if self._dedupe_on_read:
yield from self._yield_deduped_batches(batches)
if self._current_records:
yield from self._yield_current_record_batches(batches)
else:
for batch in batches:
if len(batch) > 0:
yield batch

def _yield_current_record_batches(
    self,
    batches: Iterator[pa.RecordBatch],
) -> Iterator[pa.RecordBatch]:
    """Yield only the most recent version of each record.

    When multiple versions of a record (same timdex_record_id) exist in the dataset,
    this method ensures only the most recent version is returned. If filtering is
    applied that removes this most recent version of a record, that timdex_record_id
    will not be yielded at all.

    This is achieved by iterating over TWO record batch iterators in parallel:

    1. "batches" - the RecordBatch iterator passed to this method which contains
    the actual records and columns we are interested in, and may contain filtering

    2. "unfiltered_batches" - a lightweight RecordBatch iterator that only contains
    the 'timdex_record_id' column from a pre-filtering dataset saved during .load()

    These two iterators are guaranteed to have the same number of total batches based
    on how pyarrow.Dataset.to_batches() reads from parquet files. Even if dataset
    filtering is applied, this does not affect the batch count; you may just end up
    with smaller or empty batches.

    As we move through the record batches we use unfiltered batches to keep a set of
    seen timdex_record_ids. Even if a timdex_record_id is not in the filtered record
    batch -- likely due to filtering -- we will mark that timdex_record_id as "seen"
    and not yield it from any future batches.

    Args:
        - batches: batches of records to actually yield from
    """
    # lightweight parallel stream of record ids from the pre-filtering dataset
    unfiltered_batches = self._current_records_dataset.to_batches(
        columns=["timdex_record_id"],
        batch_size=self.config.read_batch_size,
        batch_readahead=self.config.batch_read_ahead,
        fragment_readahead=self.config.fragment_read_ahead,
    )

    seen_records: set[str] = set()
    for unfiltered_batch, batch in zip(unfiltered_batches, batches, strict=True):
        # indices in this batch whose timdex_record_id has never been yielded
        unseen_batch_indices = [
            i
            for i, record_id in enumerate(
                batch.column("timdex_record_id").to_pylist()
            )
            if record_id not in seen_records
        ]

        # even if not a record to yield, update our set of seen records from ALL
        # records in the unfiltered batch, so older versions are suppressed even
        # when filtering removed the newest version from `batch`
        seen_records.update(unfiltered_batch.column("timdex_record_id").to_pylist())

        # if no unseen records from this batch, skip yielding entirely
        if not unseen_batch_indices:
            continue

        # create a new RecordBatch using the unseen indices of the batch
        deduped_batch = batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
        if len(deduped_batch) > 0:
            yield deduped_batch

def read_dataframes_iter(
self,
Expand Down
16 changes: 5 additions & 11 deletions timdex_dataset_api/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,19 @@
import concurrent.futures
import logging
import time
from typing import TYPE_CHECKING

import pandas as pd
import pyarrow.dataset as ds
import pyarrow.parquet as pq

if TYPE_CHECKING:
from timdex_dataset_api.dataset import TIMDEXDataset

logger = logging.getLogger(__name__)


class TIMDEXRunManager:
"""Manages and provides access to ETL run metadata from the TIMDEX parquet dataset."""

def __init__(self, dataset: ds.Dataset):
    """Initialize the run manager with an already-loaded pyarrow Dataset.

    Args:
        - dataset: pyarrow.dataset.Dataset whose parquet files carry run metadata
    """
    self.dataset = dataset
    # lazily populated cache of run metadata; reset via clear_cache()
    self._runs_metadata_cache: pd.DataFrame | None = None

def clear_cache(self) -> None:
Expand Down Expand Up @@ -143,7 +137,7 @@ def _get_parquet_files_run_metadata(self, max_workers: int = 250) -> pd.DataFram
"""
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for parquet_filepath in self.timdex_dataset.dataset.files: # type: ignore[attr-defined]
for parquet_filepath in self.dataset.files: # type: ignore[attr-defined]
future = executor.submit(
self._parse_run_metadata_from_parquet_file,
parquet_filepath,
Expand Down Expand Up @@ -181,7 +175,7 @@ def _parse_run_metadata_from_parquet_file(self, parquet_filepath: str) -> dict:
"""
parquet_file = pq.ParquetFile(
parquet_filepath,
filesystem=self.timdex_dataset.filesystem,
filesystem=self.dataset.filesystem, # type: ignore[attr-defined]
)

file_meta = parquet_file.metadata.to_dict()
Expand Down