
Commit fddfed0
Establish TIMDEXDataset class
Why these changes are being introduced:

All operations for a dataset can be wrapped in a TIMDEXDataset class. This class is responsible for loading a dataset that read and write operations are then performed against.

How this addresses that need:
* Establishes TIMDEXDataset class
* Adds methods for loading a dataset
* Supports local or S3 locations for the dataset

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/TIMX-415
1 parent f889439 commit fddfed0
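
A minimal usage sketch of the class this commit introduces (the dataset paths are hypothetical; the implementation appears in timdex_dataset_api/dataset.py below):

from timdex_dataset_api import TIMDEXDataset

# load a parquet dataset from a local directory (hypothetical path)
td = TIMDEXDataset.load("path/to/dataset")
print(td.row_count)

# S3 URIs are also supported; credentials are resolved via a boto3 session
td_s3 = TIMDEXDataset.load("s3://bucket/path/to/dataset")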

File tree

8 files changed (+270, -0 lines)


tests/conftest.py

Lines changed: 12 additions & 0 deletions

import pytest

from timdex_dataset_api import TIMDEXDataset


@pytest.fixture(autouse=True)
def _test_env(monkeypatch):
    monkeypatch.setenv("TDA_LOG_LEVEL", "INFO")


@pytest.fixture
def local_dataset_location():
    return "tests/fixtures/local_datasets/dataset"


@pytest.fixture
def local_dataset(local_dataset_location):
    return TIMDEXDataset.load(local_dataset_location)
tests/test_dataset.py

Lines changed: 129 additions & 0 deletions

# ruff: noqa: S105, S106, SLF001

from unittest.mock import MagicMock, patch

import pytest
from pyarrow import fs

from timdex_dataset_api.dataset import DatasetNotLoadedError, TIMDEXDataset


@pytest.mark.parametrize(
    ("location", "expected_filesystem", "expected_source"),
    [
        ("/path/to/dataset", fs.LocalFileSystem, "/path/to/dataset"),
        (
            ["/path/to/records1.parquet", "/path/to/records2.parquet"],
            fs.LocalFileSystem,
            ["/path/to/records1.parquet", "/path/to/records2.parquet"],
        ),
        ("s3://bucket/path/to/dataset", fs.S3FileSystem, "bucket/path/to/dataset"),
        (
            [
                "s3://bucket/path/to/dataset/records1.parquet",
                "s3://bucket/path/to/dataset/records2.parquet",
            ],
            fs.S3FileSystem,
            [
                "bucket/path/to/dataset/records1.parquet",
                "bucket/path/to/dataset/records2.parquet",
            ],
        ),
    ],
)
@patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
def test_parse_location_success(
    get_s3_filesystem,
    location,
    expected_filesystem,
    expected_source,
):
    get_s3_filesystem.return_value = fs.S3FileSystem()
    filesystem, source = TIMDEXDataset.parse_location(location)
    assert isinstance(filesystem, expected_filesystem)
    assert source == expected_source


@patch("timdex_dataset_api.dataset.fs.S3FileSystem")
@patch("timdex_dataset_api.dataset.boto3.session.Session")
def test_get_s3_filesystem_success(mock_session, mock_s3_filesystem):
    mock_credentials = MagicMock()
    mock_credentials.secret_key = "fake_secret_key"
    mock_credentials.access_key = "fake_access_key"
    mock_credentials.token = "fake_session_token"
    mock_session.return_value.get_credentials.return_value = mock_credentials
    mock_session.return_value.region_name = "us-east-1"

    s3_filesystem = TIMDEXDataset.get_s3_filesystem()

    mock_s3_filesystem.assert_called_once_with(
        secret_key="fake_secret_key",
        access_key="fake_access_key",
        region="us-east-1",
        session_token="fake_session_token",
    )
    assert isinstance(s3_filesystem, MagicMock)


@patch("timdex_dataset_api.dataset.fs.LocalFileSystem")
@patch("timdex_dataset_api.dataset.ds.dataset")
def test_load_local_dataset_correct_filesystem_and_source(mock_pyarrow_ds, mock_local_fs):
    mock_local_fs.return_value = MagicMock()
    mock_pyarrow_ds.return_value = MagicMock()

    timdex_dataset = TIMDEXDataset(location="local/path/to/dataset")
    loaded_dataset = timdex_dataset.load_dataset()

    mock_pyarrow_ds.assert_called_once_with(
        "local/path/to/dataset",
        schema=timdex_dataset.schema,
        format="parquet",
        partitioning="hive",
        filesystem=mock_local_fs.return_value,
    )
    assert loaded_dataset == mock_pyarrow_ds.return_value


@patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
@patch("timdex_dataset_api.dataset.ds.dataset")
def test_load_s3_dataset_correct_filesystem_and_source(mock_pyarrow_ds, mock_get_s3_fs):
    mock_get_s3_fs.return_value = MagicMock()
    mock_pyarrow_ds.return_value = MagicMock()

    timdex_dataset = TIMDEXDataset(location="s3://bucket/path/to/dataset")
    loaded_dataset = timdex_dataset.load_dataset()

    mock_get_s3_fs.assert_called_once()
    mock_pyarrow_ds.assert_called_once_with(
        "bucket/path/to/dataset",
        schema=timdex_dataset.schema,
        format="parquet",
        partitioning="hive",
        filesystem=mock_get_s3_fs.return_value,
    )
    assert loaded_dataset == mock_pyarrow_ds.return_value


@patch("timdex_dataset_api.dataset.TIMDEXDataset.load_dataset")
def test_load_method_loads_dataset_and_returns_timdexdataset_instance(mock_load_dataset):
    mock_load_dataset.return_value = MagicMock()

    timdex_dataset = TIMDEXDataset.load("s3://bucket/path/to/dataset")

    assert isinstance(timdex_dataset, TIMDEXDataset)
    assert timdex_dataset.location == "s3://bucket/path/to/dataset"
    mock_load_dataset.assert_called_once()


def test_local_dataset_is_valid(local_dataset):
    assert local_dataset.dataset.to_table().validate() is None  # where None is valid


def test_local_dataset_row_count_success(local_dataset):
    assert local_dataset.dataset.count_rows() == local_dataset.row_count


def test_local_dataset_row_count_missing_dataset_exception():
    td = TIMDEXDataset(location="path/to/nowhere")
    with pytest.raises(DatasetNotLoadedError):
        _ = td.row_count
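
The error branches in parse_location (mixed S3 and local paths, unsupported location types, shown in dataset.py below) are not exercised by the tests above; a minimal sketch of tests that could cover them, with hypothetical test names:

def test_parse_location_mixed_paths_raises_value_error():
    with pytest.raises(ValueError, match="Mixed S3 and local paths"):
        TIMDEXDataset.parse_location(
            ["s3://bucket/records1.parquet", "/local/records2.parquet"]
        )


def test_parse_location_invalid_type_raises_type_error():
    with pytest.raises(TypeError):
        TIMDEXDataset.parse_location(42)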

timdex_dataset_api/__init__.py

Lines changed: 6 additions & 0 deletions

"""timdex_dataset_api/__init__.py"""

from timdex_dataset_api.dataset import TIMDEXDataset

__version__ = "0.1.0"

__all__ = [
    "TIMDEXDataset",
]

timdex_dataset_api/dataset.py

Lines changed: 118 additions & 0 deletions

"""timdex_dataset_api/dataset.py"""

import time

import boto3
import pyarrow as pa
import pyarrow.dataset as ds
from pyarrow import fs

from timdex_dataset_api.config import configure_logger
from timdex_dataset_api.exceptions import DatasetNotLoadedError

logger = configure_logger(__name__)

TIMDEX_DATASET_SCHEMA = pa.schema(
    (
        pa.field("timdex_record_id", pa.string()),
        pa.field("source_record", pa.binary()),
        pa.field("transformed_record", pa.binary()),
        pa.field("source", pa.string()),
        pa.field("run_date", pa.date32()),
        pa.field("run_type", pa.string()),
        pa.field("run_id", pa.string()),
        pa.field("action", pa.string()),
    )
)


class TIMDEXDataset:
    """Wraps read and write operations for a TIMDEX parquet dataset."""

    def __init__(self, location: str | list[str]):
        self.location = location
        self.dataset: ds.Dataset = None  # type: ignore[assignment]
        self.schema = TIMDEX_DATASET_SCHEMA

    @classmethod
    def load(cls, location: str) -> "TIMDEXDataset":
        """Return an instantiated TIMDEXDataset object given a dataset location.

        Argument 'location' may be a local filesystem path or an S3 URI to a parquet
        dataset.
        """
        timdex_dataset = cls(location=location)
        timdex_dataset.dataset = timdex_dataset.load_dataset()
        return timdex_dataset

    @staticmethod
    def get_s3_filesystem() -> fs.FileSystem:
        """Instantiate a pyarrow S3 Filesystem for dataset loading."""
        session = boto3.session.Session()
        credentials = session.get_credentials()
        if not credentials:
            raise RuntimeError("Could not locate AWS credentials")
        return fs.S3FileSystem(
            secret_key=credentials.secret_key,
            access_key=credentials.access_key,
            region=session.region_name,
            session_token=credentials.token,
        )

    @staticmethod
    def parse_location(
        location: str | list[str],
    ) -> tuple[fs.FileSystem, str | list[str]]:
        """Parse and return the filesystem and normalized source location(s).

        Handles both single location strings and lists of Parquet file paths.
        """
        source: str | list[str]
        if isinstance(location, str):
            if location.startswith("s3://"):
                filesystem = TIMDEXDataset.get_s3_filesystem()
                source = location.removeprefix("s3://")
            else:
                filesystem = fs.LocalFileSystem()
                source = location
        elif isinstance(location, list):
            if all(loc.startswith("s3://") for loc in location):
                filesystem = TIMDEXDataset.get_s3_filesystem()
                source = [loc.removeprefix("s3://") for loc in location]
            elif all(not loc.startswith("s3://") for loc in location):
                filesystem = fs.LocalFileSystem()
                source = location
            else:
                raise ValueError("Mixed S3 and local paths are not supported.")
        else:
            raise TypeError("Location type must be str or list[str].")

        return filesystem, source

    def load_dataset(self) -> ds.Dataset:
        """Lazy load a pyarrow.Dataset for an already instantiated TIMDEXDataset object.

        The dataset is loaded via the expected schema as defined by module constant
        TIMDEX_DATASET_SCHEMA. If the target dataset differs in any way, errors may be
        raised when reading or writing data.
        """
        start_time = time.perf_counter()
        filesystem, source = self.parse_location(self.location)
        dataset = ds.dataset(
            source,
            schema=self.schema,
            format="parquet",
            partitioning="hive",
            filesystem=filesystem,
        )
        logger.info(
            f"Dataset successfully loaded: '{self.location}', "
            f"{round(time.perf_counter() - start_time, 2)}s"
        )
        return dataset

    @property
    def row_count(self) -> int:
        """Get row count from loaded dataset."""
        if not self.dataset:
            raise DatasetNotLoadedError
        return self.dataset.count_rows()
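
Because load_dataset applies TIMDEX_DATASET_SCHEMA, the underlying pyarrow dataset can be queried with filters on schema columns; a short sketch using the local fixture dataset from conftest.py (the "alma" source value is hypothetical):

import pyarrow.dataset as ds

from timdex_dataset_api import TIMDEXDataset

td = TIMDEXDataset.load("tests/fixtures/local_datasets/dataset")

# filter on a column from TIMDEX_DATASET_SCHEMA; "alma" is a hypothetical source
table = td.dataset.to_table(filter=ds.field("source") == "alma")
print(table.num_rows)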

timdex_dataset_api/exceptions.py

Lines changed: 5 additions & 0 deletions

"""timdex_dataset_api/exceptions.py"""


class DatasetNotLoadedError(Exception):
    """Custom exception for accessing methods requiring a loaded dataset."""
