INTPYTHON-520 Ensure all parquet data types are handled #338


Merged 3 commits on Jul 30, 2025.
35 changes: 32 additions & 3 deletions bindings/python/pymongoarrow/api.py
@@ -35,7 +35,20 @@
from numpy import ndarray
from pyarrow import Schema as ArrowSchema
from pyarrow import Table, timestamp
-from pyarrow.types import is_date32, is_date64
+from pyarrow.types import (
+    is_date32,
+    is_date64,
+    is_duration,
+    is_float16,
+    is_float32,
+    is_int8,
+    is_int16,
+    is_list,
+    is_uint8,
+    is_uint16,
+    is_uint32,
+    is_uint64,
+)
from pymongo.common import MAX_WRITE_BATCH_SIZE

from pymongoarrow.context import PyMongoArrowContext
@@ -475,14 +488,15 @@ def transform_python(self, value):
return Decimal128(value)


-def write(collection, tabular, *, exclude_none: bool = False):
+def write(collection, tabular, *, exclude_none: bool = False, auto_convert: bool = True):
"""Write data from `tabular` into the given MongoDB `collection`.

:Parameters:
- `collection`: Instance of :class:`~pymongo.collection.Collection`
  against which to run the operation.
- `tabular`: A tabular data store to use for the write operation.
- `exclude_none`: Whether to skip writing `null` fields in documents.
- `auto_convert` (optional): Whether to attempt a best-effort conversion of Arrow types
  that have no direct BSON equivalent (unsigned and 8-/16-bit integers, 16-/32-bit
  floats, durations) to supported types. Defaults to ``True``.

:Returns:
An instance of :class:`result.ArrowWriteResult`.
@@ -500,9 +514,24 @@ def write(collection, tabular, *, exclude_none: bool = False):
if is_date32(dtype) or is_date64(dtype):
changed = True
dtype = timestamp("ms") # noqa: PLW2901
elif auto_convert:
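# BSON has no unsigned or sub-32-bit integer types, no 16-/32-bit floats, and
# no duration type, so promote each to the nearest BSON-representable type
# (int32, int64, or float64); durations become their raw integer counts.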
if is_uint8(dtype) or is_uint16(dtype) or is_int8(dtype) or is_int16(dtype):
changed = True
dtype = pa.int32() # noqa: PLW2901
elif is_uint32(dtype) or is_uint64(dtype) or is_duration(dtype):
changed = True
dtype = pa.int64() # noqa: PLW2901
elif is_float16(dtype) or is_float32(dtype):
changed = True
dtype = pa.float64() # noqa: PLW2901
new_types.append(dtype)
if changed:
-cols = [tabular.column(i).cast(new_types[i]) for i in range(tabular.num_columns)]
+cols = [
+    tabular.column(i).cast(new_types[i])
+    if not is_list(new_types[i])
+    else tabular.column(i)
+    for i in range(tabular.num_columns)
+]
tabular = Table.from_arrays(cols, names=tabular.column_names)
_validate_schema(tabular.schema.types)
elif pd is not None and isinstance(tabular, pd.DataFrame):
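Taken together, the new branch promotes Arrow types with no BSON counterpart instead of rejecting them. A minimal sketch of the behavior this enables (the client, database, and collection names are illustrative, and the `auto_convert=False` outcome assumes `_validate_schema` still rejects these types, as in the prior code):

import numpy as np
import pyarrow as pa
from pymongo import MongoClient
from pymongoarrow.api import write

# Hypothetical local client and collection; substitute your own connection.
coll = MongoClient().demo_db.demo_coll

table = pa.table(
    {
        "u8": pa.array([1, 2, 3], type=pa.uint8()),         # promoted to int32
        "u64": pa.array([4, 5, 6], type=pa.uint64()),       # promoted to int64
        "f16": pa.array(np.array([1, 2, 3], dtype=np.float16)),  # promoted to float64
        "dur": pa.array([7, 8, 9], type=pa.duration("s")),  # promoted to int64
    }
)

write(coll, table)  # casts the columns above, then writes int32/int64/float64
# write(coll, table, auto_convert=False)  # expected to fail schema validation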
124 changes: 123 additions & 1 deletion bindings/python/test/test_arrow.py
@@ -19,14 +19,15 @@
import threading
import unittest
import unittest.mock as mock
-from datetime import date, datetime
+from datetime import date, datetime, timedelta
from pathlib import Path
from test import client_context
from test.utils import AllowListEventListener, NullsTestMixin

import pyarrow as pa
import pyarrow.json
import pymongo
import pytest
from bson import Binary, Code, CodecOptions, Decimal128, ObjectId, json_util
from pyarrow import (
Table,
@@ -63,6 +64,11 @@
ObjectIdType,
)

try:
import pandas as pd
except ImportError:
pd = None

HERE = Path(__file__).absolute().parent


@@ -1082,6 +1088,122 @@ def test_decimal128(self):
coll_data = list(self.coll.find({}))
assert coll_data[0]["data"] == Decimal128(a)

def alltypes_sample(self, size=10000, seed=0, categorical=False):
# modified from https://github.com/apache/arrow/blob/main/python/pyarrow/tests/parquet/common.py#L139
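# One column per logical type, so that write() exercises every conversion path.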
import numpy as np
import pandas as pd

np.random.seed(seed)
arrays = {
"uint8": np.arange(size, dtype=np.uint8),
"uint16": np.arange(size, dtype=np.uint16),
"uint32": np.arange(size, dtype=np.uint32),
"uint64": np.arange(size, dtype=np.uint64),
"int8": np.arange(size, dtype=np.int8),
"int16": np.arange(size, dtype=np.int16),
"int32": np.arange(size, dtype=np.int32),
"int64": np.arange(size, dtype=np.int64),
"float16": np.arange(size, dtype=np.float16),
"float32": np.arange(size, dtype=np.float32),
"float64": np.arange(size, dtype=np.float64),
"bool": np.random.randn(size) > 0,
"datetime_ms": np.arange("2016-01-01T00:00:00.001", size, dtype="datetime64[ms]"),
"datetime_us": np.arange("2016-01-01T00:00:00.000001", size, dtype="datetime64[us]"),
"datetime_ns": np.arange("2016-01-01T00:00:00.000000001", size, dtype="datetime64[ns]"),
"timedelta": np.arange(size, dtype="timedelta64[s]"),
"str": pd.Series([str(x) for x in range(size)]),
"empty_str": [""] * size,
"str_with_nulls": [None] + [str(x) for x in range(size - 2)] + [None],
"null": [None] * size,
"null_list": [None] * 2 + [[None] * (x % 4) for x in range(size - 2)],
}
if categorical:
arrays["str_category"] = arrays["str"].astype("category")
return pd.DataFrame(arrays)

def convert_categorical_columns_to_string(self, table):
"""
Converts any categorical columns in an Arrow Table into string columns.
This preprocessing step ensures compatibility with PyMongoArrow schema validation.
"""
new_columns = []
for column_name, column in zip(table.column_names, table.columns):
if pa.types.is_dictionary(column.type):
# Convert dictionary (categorical) columns to string
new_columns.append(pa.array(column.combine_chunks().to_pandas(), type=pa.string()))
else:
# Keep other column types intact
new_columns.append(column)
# Return a new Arrow Table
return pa.Table.from_arrays(new_columns, names=table.column_names)

def compare_arrow_mongodb_data(self, arrow_table, mongo_data):
"""
Compare data types and precision between an Arrow Table and MongoDB documents.

Params:
arrow_table (pyarrow.Table): The original Arrow Table before insertion.
mongo_data (list): The list of MongoDB documents fetched from the collection.

Raises:
AssertionError: If any data type or value doesn't match between Arrow and MongoDB.
"""
import decimal

import numpy as np

assert len(mongo_data) == len(arrow_table), "MongoDB data length mismatch with Arrow Table."

# Convert Arrow Table to Python dict format for comparison
arrow_dict = arrow_table.to_pydict()

for row_idx in range(arrow_table.num_rows):
mongo_document = mongo_data[row_idx] # Fetch the corresponding MongoDB document

for column_name in arrow_table.column_names:
arrow_value = arrow_dict[column_name][
row_idx
] # Get the value from the Arrow Table (Python representation)
mongo_value = mongo_document.get(column_name) # Get the value from MongoDB

if isinstance(arrow_value, decimal.Decimal):
assert (
Decimal128(arrow_value).to_decimal() == Decimal128(mongo_value).to_decimal()
), f"Precision loss in decimal field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
elif isinstance(arrow_value, (np.datetime64, pd.Timestamp, datetime)):
arrow_value_rounded = pd.Timestamp(arrow_value).round(
"ms"
) # Round to milliseconds
assert (
arrow_value_rounded.to_pydatetime() == mongo_value
), f"Datetime mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value_rounded}, got {mongo_value}."
elif isinstance(arrow_value, (list, np.ndarray)):
assert (
arrow_value == mongo_value
), f"List mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
elif isinstance(arrow_value, timedelta):
assert (
arrow_value == timedelta(seconds=mongo_value)
), f"Timedelta mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
else:
assert (
arrow_value == mongo_value
), f"Value mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."

def test_all_types(self):
"""
Test the conversion of all standard data types from Parquet → PyArrow → Python → BSON.
"""
if pd is None:
pytest.skip("Requires pandas.", allow_module_level=True)
df = self.alltypes_sample(size=100, seed=42, categorical=True)
arrow_table = pa.Table.from_pandas(df)
arrow_table = self.convert_categorical_columns_to_string(arrow_table)
self.coll.drop()
write(self.coll, arrow_table)
coll_data = list(self.coll.find({}))
self.compare_arrow_mongodb_data(arrow_table, coll_data)
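# Note: the original column widths are not recoverable from BSON; reading the
# collection back (e.g. with pymongoarrow.api.find_arrow_all) yields the
# promoted int32/int64/float64 types rather than uint8/float16/etc.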

def test_empty_embedded_array(self):
# From INTPYTHON-575.
self.coll.drop()