
Commit 1162835

Add dataset writing functionality
Why these changes are being introduced:

A primary use case for this library will be Transmogrifier writing new records to the parquet dataset. This library is intended to make that work simple for Transmogrifier: all it needs to do is yield DatasetRecords (imported from this library) to the write() method.

How this addresses that need:
* Adds new write() entrypoint method
* Includes helper methods to batch records yielded to the write() method
* Adds a DatasetRecord class designed to encapsulate each record (row) that will get written to the dataset

Side effects of this change:
* Library supports writing to a local or remote dataset

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/TIMX-415
1 parent 69ae96b commit 1162835
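
As context for reviewers, a minimal sketch of the intended Transmogrifier workflow, pieced together from the tests in this commit; the dataset location and record values below are illustrative only and are not part of the change:

    from timdex_dataset_api import DatasetRecord, TIMDEXDataset

    def records():
        # Transmogrifier would yield one DatasetRecord per transformed source record
        yield DatasetRecord(
            timdex_record_id="alma:123",
            source_record=b"<record><title>Hello World.</title></record>",
            transformed_record=b'{"title":["Hello World."]}',
        )

    # location may be a local path or an s3:// URI pointing at the dataset root
    timdex_dataset = TIMDEXDataset(location="path/to/dataset")
    written_files = timdex_dataset.write(
        records(),
        partition_values={
            "source": "alma",
            "run_date": "2024-12-01",
            "run_type": "daily",
            "action": "index",
            "run_id": "000-111-aaa-bbb",
        },
    )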

File tree

8 files changed: 466 additions, 3 deletions


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ ignore = [
     "D103",
     "D104",
     "D415",
+    "D417",
     "EM102",
     "G004",
     "PLR0912",

tests/conftest.py

Lines changed: 38 additions & 0 deletions
@@ -1,7 +1,11 @@
 """tests/conftest.py"""
 
+# ruff: noqa: D205, D209
+
+
 import pytest
 
+from tests.utils import generate_sample_records
 from timdex_dataset_api import TIMDEXDataset
 
 
@@ -22,3 +26,37 @@ def local_dataset_location():
 @pytest.fixture
 def local_dataset(local_dataset_location):
     return TIMDEXDataset.load(local_dataset_location)
+
+
+@pytest.fixture
+def new_temp_dataset(tmp_path) -> TIMDEXDataset:
+    location = str(tmp_path / "new_dataset")
+    return TIMDEXDataset(location=location)
+
+
+@pytest.fixture
+def small_records_iter():
+    """Simulates an iterator of X number of valid DatasetRecord instances."""
+
+    def _records_iter(num_records):
+        return generate_sample_records(num_records)
+
+    return _records_iter
+
+
+@pytest.fixture
+def small_records_iter_without_partitions():
+    """Simulates an iterator of X number of DatasetRecord instances WITHOUT partition
+    values included."""
+
+    def _records_iter(num_records):
+        return generate_sample_records(
+            num_records,
+            source=None,
+            run_date=None,
+            run_type=None,
+            action=None,
+            run_id=None,
+        )
+
+    return _records_iter

tests/test_dataset.py

Lines changed: 23 additions & 2 deletions
@@ -1,5 +1,4 @@
 # ruff: noqa: S105, S106, SLF001
-
 from unittest.mock import MagicMock, patch
 
 import pyarrow as pa
@@ -33,7 +32,7 @@
     ],
 )
 @patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
-def test_parse_location_single_local_directory(
+def test_parse_location_success_scenarios(
     get_s3_filesystem,
     location,
     expected_filesystem,
@@ -45,6 +44,28 @@ def test_parse_location_single_local_directory(
     assert source == expected_source
 
 
+@pytest.mark.parametrize(
+    ("location", "expected_exception"),
+    [
+        # None is invalid location type
+        (None, TypeError),
+        # mixed local and S3 locations
+        (
+            [
+                "/local/path/to/dataset/records.parquet",
+                "s3://path/to/dataset/records.parquet",
+            ],
+            ValueError,
+        ),
+    ],
+)
+@patch("timdex_dataset_api.dataset.TIMDEXDataset.get_s3_filesystem")
+def test_parse_location_error_scenarios(get_s3_filesystem, location, expected_exception):
+    get_s3_filesystem.return_value = fs.S3FileSystem()
+    with pytest.raises(expected_exception):
+        _ = TIMDEXDataset.parse_location(location)
+
+
 def test_get_s3_filesystem_success(mocker):
     mocked_s3_filesystem = mocker.spy(fs, "S3FileSystem")
     s3_filesystem = TIMDEXDataset.get_s3_filesystem()

tests/test_dataset_write.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
+# ruff: noqa: S105, S106, SLF001, PLR2004, PD901, D209, D205
+
+import datetime
+import math
+import os
+
+import pyarrow.dataset as ds
+import pytest
+
+from timdex_dataset_api.dataset import (
+    MAX_ROWS_PER_FILE,
+    TIMDEX_DATASET_SCHEMA,
+    DatasetNotLoadedError,
+    TIMDEXDataset,
+)
+from timdex_dataset_api.record import DatasetRecord
+
+
+def test_dataset_record_serialization():
+    dataset_record = DatasetRecord(
+        timdex_record_id="alma:123",
+        source_record=b"<record><title>Hello World.</title></record>",
+        transformed_record=b"""{"title":["Hello World."]}""",
+    )
+    assert dataset_record.to_dict() == {
+        "timdex_record_id": "alma:123",
+        "source_record": b"<record><title>Hello World.</title></record>",
+        "transformed_record": b"""{"title":["Hello World."]}""",
+        "source": None,
+        "run_date": None,
+        "run_type": None,
+        "action": None,
+        "run_id": None,
+    }
+
+
+def test_dataset_record_serialization_with_partition_values_provided():
+    dataset_record = DatasetRecord(
+        timdex_record_id="alma:123",
+        source_record=b"<record><title>Hello World.</title></record>",
+        transformed_record=b"""{"title":["Hello World."]}""",
+    )
+    partition_values = {
+        "source": "alma",
+        "run_date": "2024-12-01",
+        "run_type": "daily",
+        "action": "index",
+        "run_id": "000-111-aaa-bbb",
+    }
+    assert dataset_record.to_dict(partition_values=partition_values) == {
+        "timdex_record_id": "alma:123",
+        "source_record": b"<record><title>Hello World.</title></record>",
+        "transformed_record": b"""{"title":["Hello World."]}""",
+        "source": "alma",
+        "run_date": "2024-12-01",
+        "run_type": "daily",
+        "action": "index",
+        "run_id": "000-111-aaa-bbb",
+    }
+
+
+def test_dataset_write_records_to_new_dataset(new_temp_dataset, small_records_iter):
+    files_written = new_temp_dataset.write(small_records_iter(10_000))
+    assert len(files_written) == 1
+    assert os.path.exists(new_temp_dataset.location)
+
+    # load newly created dataset as new TIMDEXDataset instance
+    dataset = TIMDEXDataset.load(new_temp_dataset.location)
+    assert dataset.row_count == 10_000
+
+
+def test_dataset_reload_after_write(new_temp_dataset, small_records_iter):
+    files_written = new_temp_dataset.write(small_records_iter(10_000))
+    assert len(files_written) == 1
+    assert os.path.exists(new_temp_dataset.location)
+
+    # attempt row count before reload
+    with pytest.raises(DatasetNotLoadedError):
+        _ = new_temp_dataset.row_count
+
+    # attempt row count after reload
+    new_temp_dataset.reload()
+    assert new_temp_dataset.row_count == 10_000
+
+
+def test_dataset_write_default_max_rows_per_file(new_temp_dataset, small_records_iter):
+    """Default is 100k rows per file, therefore writing 200,033 records should result in
+    3 files (x2 @ 100k rows, x1 @ 33 rows)."""
+    total_records = 200_033
+
+    new_temp_dataset.write(small_records_iter(total_records))
+    new_temp_dataset.reload()
+
+    assert new_temp_dataset.row_count == total_records
+    assert len(new_temp_dataset.dataset.files) == math.ceil(
+        total_records / MAX_ROWS_PER_FILE
+    )
+
+
+def test_dataset_write_record_batches_uses_batch_size(
+    new_temp_dataset, small_records_iter
+):
+    total_records = 101
+    batch_size = 50
+    batches = list(
+        new_temp_dataset.get_dataset_record_batches(
+            small_records_iter(total_records), batch_size=batch_size
+        )
+    )
+    assert len(batches) == math.ceil(total_records / batch_size)
+
+
+def test_dataset_write_to_multiple_locations_raise_error(small_records_iter):
+    timdex_dataset = TIMDEXDataset(
+        location=["/path/to/records-1.parquet", "/path/to/records-2.parquet"]
+    )
+    with pytest.raises(
+        TypeError,
+        match="Dataset location must be the root of a single dataset for writing",
+    ):
+        timdex_dataset.write(small_records_iter(10))
+
+
+def test_dataset_write_mixin_partition_values_used(
+    new_temp_dataset, small_records_iter_without_partitions
+):
+    partition_values = {
+        "source": "alma",
+        "run_date": "2024-12-01",
+        "run_type": "daily",
+        "action": "index",
+        "run_id": "000-111-aaa-bbb",
+    }
+    _written_files = new_temp_dataset.write(
+        small_records_iter_without_partitions(10),
+        partition_values=partition_values,
+    )
+    new_temp_dataset.reload()
+
+    # load as pandas dataframe and assert column values
+    df = new_temp_dataset.dataset.to_table().to_pandas()
+    row = df.iloc[0]
+    assert row.source == partition_values["source"]
+    assert row.run_date == datetime.date(2024, 12, 1)
+    assert row.run_type == partition_values["run_type"]
+    assert row.action == partition_values["action"]
+    assert row.run_id == partition_values["run_id"]
+
+
+def test_dataset_write_schema_partitions_correctly_ordered(
+    new_temp_dataset, small_records_iter
+):
+    written_files = new_temp_dataset.write(
+        small_records_iter(10),
+        partition_values={
+            "source": "alma",
+            "run_date": "2024-12-01",
+            "run_type": "daily",
+            "action": "index",
+            "run_id": "000-111-aaa-bbb",
+        },
+    )
+    file = written_files[0]
+    assert (
+        "/source=alma/run_date=2024-12-01/run_type=daily"
+        "/action=index/run_id=000-111-aaa-bbb" in file.path
+    )
+
+
+def test_dataset_write_schema_applied_to_dataset(new_temp_dataset, small_records_iter):
+    new_temp_dataset.write(small_records_iter(10))
+
+    # manually load dataset to confirm schema without TIMDEXDataset projecting schema
+    # during load
+    dataset = ds.dataset(
+        new_temp_dataset.location,
+        format="parquet",
+        partitioning="hive",
+    )
+
+    assert set(dataset.schema.names) == set(TIMDEX_DATASET_SCHEMA.names)

tests/utils.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+"""tests/utils.py"""
+
+# ruff: noqa: S311
+
+import random
+import uuid
+from collections.abc import Iterator
+
+from timdex_dataset_api import DatasetRecord
+
+
+def generate_sample_records(
+    num_records: int,
+    timdex_record_id_prefix: str = "alma",
+    source: str | None = "alma",
+    run_date: str | None = "2024-12-01",
+    run_type: str | None = "daily",
+    action: str | None = "index",
+    run_id: str | None = None,
+) -> Iterator[DatasetRecord]:
+    """Generate sample DatasetRecords."""
+    if not run_id:
+        run_id = str(uuid.uuid4())
+
+    for x in range(num_records):
+        yield DatasetRecord(
+            timdex_record_id=f"{timdex_record_id_prefix}:{x}",
+            source_record=b"<record><title>Hello World.</title></record>",
+            transformed_record=b"""{"title":["Hello World."]}""",
+            source=source,
+            run_date=run_date,
+            run_type=run_type,
+            action=action,
+            run_id=run_id,
+        )
+
+
+def generate_sample_records_with_simulated_partitions(
+    num_records: int, num_run_ids: int = 4
+) -> Iterator[DatasetRecord]:
+    """Generate sample DatasetRecords, with simulated sampling of partitions."""
+    sources = ["alma", "dspace", "aspace", "libguides", "gismit", "gisogm"]
+    run_dates = ["2024-01-01", "2024-06-15", "2024-12-31"]
+    run_types = ["full", "daily"]
+    actions = ["index", "delete"]
+    run_ids = [str(uuid.uuid4()) for x in range(num_run_ids)]
+
+    records_remaining = num_records
+    while records_remaining > 0:
+        batch_size = random.randint(1, min(100, records_remaining))
+        yield from generate_sample_records(
+            num_records=batch_size,
+            timdex_record_id_prefix=random.choice(sources),
+            source=random.choice(sources),
+            run_date=random.choice(run_dates),
+            run_type=random.choice(run_types),
+            action=random.choice(actions),
+            run_id=random.choice(run_ids),
+        )
+        records_remaining -= batch_size

timdex_dataset_api/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,9 +1,11 @@
 """timdex_dataset_api/__init__.py"""
 
 from timdex_dataset_api.dataset import TIMDEXDataset
+from timdex_dataset_api.record import DatasetRecord
 
 __version__ = "0.1.0"
 
 __all__ = [
+    "DatasetRecord",
     "TIMDEXDataset",
 ]
