Skip to content

Commit b9919b3

Browse files
authored
INTPYTHON-520 Ensure all parquet data types are handled (#338)
1 parent c472854 commit b9919b3

File tree

2 files changed

+155
-4
lines changed

2 files changed

+155
-4
lines changed

bindings/python/pymongoarrow/api.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,20 @@
3535
from numpy import ndarray
3636
from pyarrow import Schema as ArrowSchema
3737
from pyarrow import Table, timestamp
38-
from pyarrow.types import is_date32, is_date64
38+
from pyarrow.types import (
39+
is_date32,
40+
is_date64,
41+
is_duration,
42+
is_float16,
43+
is_float32,
44+
is_int8,
45+
is_int16,
46+
is_list,
47+
is_uint8,
48+
is_uint16,
49+
is_uint32,
50+
is_uint64,
51+
)
3952
from pymongo.common import MAX_WRITE_BATCH_SIZE
4053

4154
from pymongoarrow.context import PyMongoArrowContext
@@ -475,14 +488,15 @@ def transform_python(self, value):
475488
return Decimal128(value)
476489

477490

478-
def write(collection, tabular, *, exclude_none: bool = False):
491+
def write(collection, tabular, *, exclude_none: bool = False, auto_convert: bool = True):
479492
"""Write data from `tabular` into the given MongoDB `collection`.
480493
481494
:Parameters:
482495
- `collection`: Instance of :class:`~pymongo.collection.Collection`.
483496
against which to run the operation.
484497
- `tabular`: A tabular data store to use for the write operation.
485498
- `exclude_none`: Whether to skip writing `null` fields in documents.
499+
- `auto_convert` (optional): Whether to attempt a best-effort conversion of unsupported types.
486500
487501
:Returns:
488502
An instance of :class:`result.ArrowWriteResult`.
@@ -500,9 +514,24 @@ def write(collection, tabular, *, exclude_none: bool = False):
500514
if is_date32(dtype) or is_date64(dtype):
501515
changed = True
502516
dtype = timestamp("ms") # noqa: PLW2901
517+
elif auto_convert:
518+
if is_uint8(dtype) or is_uint16(dtype) or is_int8(dtype) or is_int16(dtype):
519+
changed = True
520+
dtype = pa.int32() # noqa: PLW2901
521+
elif is_uint32(dtype) or is_uint64(dtype) or is_duration(dtype):
522+
changed = True
523+
dtype = pa.int64() # noqa: PLW2901
524+
elif is_float16(dtype) or is_float32(dtype):
525+
changed = True
526+
dtype = pa.float64() # noqa: PLW2901
503527
new_types.append(dtype)
504528
if changed:
505-
cols = [tabular.column(i).cast(new_types[i]) for i in range(tabular.num_columns)]
529+
cols = [
530+
tabular.column(i).cast(new_types[i])
531+
if not is_list(new_types[i])
532+
else tabular.column(i)
533+
for i in range(tabular.num_columns)
534+
]
506535
tabular = Table.from_arrays(cols, names=tabular.column_names)
507536
_validate_schema(tabular.schema.types)
508537
elif pd is not None and isinstance(tabular, pd.DataFrame):

bindings/python/test/test_arrow.py

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
import threading
2020
import unittest
2121
import unittest.mock as mock
22-
from datetime import date, datetime
22+
from datetime import date, datetime, timedelta
2323
from pathlib import Path
2424
from test import client_context
2525
from test.utils import AllowListEventListener, NullsTestMixin
2626

2727
import pyarrow as pa
2828
import pyarrow.json
2929
import pymongo
30+
import pytest
3031
from bson import Binary, Code, CodecOptions, Decimal128, ObjectId, json_util
3132
from pyarrow import (
3233
Table,
@@ -63,6 +64,11 @@
6364
ObjectIdType,
6465
)
6566

67+
try:
68+
import pandas as pd
69+
except ImportError:
70+
pd = None
71+
6672
HERE = Path(__file__).absolute().parent
6773

6874

@@ -1082,6 +1088,122 @@ def test_decimal128(self):
10821088
coll_data = list(self.coll.find({}))
10831089
assert coll_data[0]["data"] == Decimal128(a)
10841090

1091+
def alltypes_sample(self, size=10000, seed=0, categorical=False):
    """Build a pandas DataFrame covering every dtype the parquet round-trip must handle.

    Adapted from
    https://github.com/apache/arrow/blob/main/python/pyarrow/tests/parquet/common.py#L139

    :Parameters:
      - `size`: number of rows per column.
      - `seed`: seed for the numpy RNG (controls the ``bool`` column).
      - `categorical`: when True, also add a categorical copy of the ``str`` column.
    """
    import numpy as np
    import pandas as pd

    # Seed before any random draw so the "bool" column is reproducible.
    np.random.seed(seed)

    columns = {}
    # Fixed-width integers, unsigned then signed (insertion order defines column order).
    for int_name in ("uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"):
        columns[int_name] = np.arange(size, dtype=getattr(np, int_name))
    # Floats at every precision.
    for float_name in ("float16", "float32", "float64"):
        columns[float_name] = np.arange(size, dtype=getattr(np, float_name))
    columns["bool"] = np.random.randn(size) > 0
    # Datetimes at ms/us/ns resolution; the integer stop gives `size` steps from the start.
    columns["datetime_ms"] = np.arange("2016-01-01T00:00:00.001", size, dtype="datetime64[ms]")
    columns["datetime_us"] = np.arange("2016-01-01T00:00:00.000001", size, dtype="datetime64[us]")
    columns["datetime_ns"] = np.arange("2016-01-01T00:00:00.000000001", size, dtype="datetime64[ns]")
    columns["timedelta"] = np.arange(size, dtype="timedelta64[s]")
    # String variants: plain, empty, and with a null at each end.
    columns["str"] = pd.Series([str(x) for x in range(size)])
    columns["empty_str"] = [""] * size
    columns["str_with_nulls"] = [None] + [str(x) for x in range(size - 2)] + [None]
    # All-null column and a column of variable-length lists (leading nulls).
    columns["null"] = [None] * size
    columns["null_list"] = [None] * 2 + [[None] * (x % 4) for x in range(size - 2)]
    if categorical:
        columns["str_category"] = columns["str"].astype("category")
    return pd.DataFrame(columns)
1123+
1124+
def convert_categorical_columns_to_string(self, table):
    """
    Return a copy of `table` with every dictionary-encoded (categorical) column
    replaced by a plain string column; all other columns are passed through
    untouched. This preprocessing step ensures compatibility with PyMongoArrow
    schema validation, which does not accept dictionary types.
    """

    def _decode_to_string(col):
        # Materialize the dictionary values via pandas, then re-encode as utf8.
        return pa.array(col.combine_chunks().to_pandas(), type=pa.string())

    converted = [
        _decode_to_string(col) if pa.types.is_dictionary(col.type) else col
        for col in table.columns
    ]
    return pa.Table.from_arrays(converted, names=table.column_names)
1139+
1140+
def compare_arrow_mongodb_data(self, arrow_table, mongo_data):
    """
    Compare data types and precision between an Arrow Table and MongoDB documents.

    Params:
        arrow_table (pyarrow.Table): The original Arrow Table before insertion.
        mongo_data (list): The list of MongoDB documents fetched from the collection.

    Raises:
        AssertionError: If any data type or value doesn't match between Arrow and MongoDB.
    """
    import decimal

    import numpy as np

    assert len(mongo_data) == len(arrow_table), "MongoDB data length mismatch with Arrow Table."

    # Work on the plain-Python view of the table so values compare naturally.
    table_as_dict = arrow_table.to_pydict()

    for row_idx, document in enumerate(mongo_data):
        for column_name in arrow_table.column_names:
            expected = table_as_dict[column_name][row_idx]
            actual = document.get(column_name)

            if isinstance(expected, decimal.Decimal):
                # Normalize both sides through Decimal128 to compare at BSON precision.
                assert (
                    Decimal128(expected).to_decimal() == Decimal128(actual).to_decimal()
                ), f"Precision loss in decimal field '{column_name}' for row {row_idx}. Expected {expected}, got {actual}."
            elif isinstance(expected, (np.datetime64, pd.Timestamp, datetime)):
                # BSON datetimes carry millisecond precision, so round before comparing.
                rounded = pd.Timestamp(expected).round("ms")
                assert (
                    rounded.to_pydatetime() == actual
                ), f"Datetime mismatch in field '{column_name}' for row {row_idx}. Expected {rounded}, got {actual}."
            elif isinstance(expected, (list, np.ndarray)):
                assert (
                    expected == actual
                ), f"List mismatch in field '{column_name}' for row {row_idx}. Expected {expected}, got {actual}."
            elif isinstance(expected, timedelta):
                # Durations are written as an integer count of seconds.
                assert (
                    expected == timedelta(seconds=actual)
                ), f"Timedelta mismatch in field '{column_name}' for row {row_idx}. Expected {expected}, got {actual}."
            else:
                assert (
                    expected == actual
                ), f"Value mismatch in field '{column_name}' for row {row_idx}. Expected {expected}, got {actual}."
1192+
1193+
def test_all_types(self):
    """
    Test the conversion of all standard data types from Parquet → PyArrow → Python → BSON.
    """
    if pd is None:
        # self.skipTest works under both the unittest and pytest runners;
        # pytest.skip(allow_module_level=True) is only meant for module scope
        # and is not honored as a skip by the plain unittest runner.
        self.skipTest("Requires pandas.")
    df = self.alltypes_sample(size=100, seed=42, categorical=True)
    arrow_table = pa.Table.from_pandas(df)
    # Dictionary (categorical) columns are not accepted by write(); decode to string first.
    arrow_table = self.convert_categorical_columns_to_string(arrow_table)
    self.coll.drop()
    write(self.coll, arrow_table)
    coll_data = list(self.coll.find({}))
    # Verify every column survived the round trip without type or precision loss.
    self.compare_arrow_mongodb_data(arrow_table, coll_data)
1206+
10851207
def test_empty_embedded_array(self):
10861208
# From INTPYTHON-575.
10871209
self.coll.drop()

0 commit comments

Comments
 (0)