diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py
index 60b32d29..ce53bd5e 100644
--- a/bindings/python/pymongoarrow/api.py
+++ b/bindings/python/pymongoarrow/api.py
@@ -35,7 +35,20 @@
 from numpy import ndarray
 from pyarrow import Schema as ArrowSchema
-from pyarrow import Table, timestamp
-from pyarrow.types import is_date32, is_date64
+from pyarrow import Table, float64, int32, int64, timestamp
+from pyarrow.types import (
+    is_date32,
+    is_date64,
+    is_duration,
+    is_float16,
+    is_float32,
+    is_int8,
+    is_int16,
+    is_list,
+    is_uint8,
+    is_uint16,
+    is_uint32,
+    is_uint64,
+)
 from pymongo.common import MAX_WRITE_BATCH_SIZE
 
 from pymongoarrow.context import PyMongoArrowContext
@@ -475,7 +488,7 @@ def transform_python(self, value):
         return Decimal128(value)
 
 
-def write(collection, tabular, *, exclude_none: bool = False):
+def write(collection, tabular, *, exclude_none: bool = False, auto_convert: bool = True):
     """Write data from `tabular` into the given MongoDB `collection`.
 
     :Parameters:
@@ -483,6 +496,7 @@
         against which to run the operation.
       - `tabular`: A tabular data store to use for the write operation.
      - `exclude_none`: Whether to skip writing `null` fields in documents.
+      - `auto_convert` (optional): Whether to attempt a best-effort conversion of unsupported types.
 
     :Returns:
       An instance of :class:`result.ArrowWriteResult`.
@@ -500,9 +514,24 @@
             if is_date32(dtype) or is_date64(dtype):
                 changed = True
                 dtype = timestamp("ms")  # noqa: PLW2901
+            elif auto_convert:
+                if is_uint8(dtype) or is_uint16(dtype) or is_int8(dtype) or is_int16(dtype):
+                    changed = True
+                    dtype = int32()  # noqa: PLW2901
+                elif is_uint32(dtype) or is_uint64(dtype) or is_duration(dtype):
+                    changed = True
+                    dtype = int64()  # noqa: PLW2901
+                elif is_float16(dtype) or is_float32(dtype):
+                    changed = True
+                    dtype = float64()  # noqa: PLW2901
             new_types.append(dtype)
         if changed:
-            cols = [tabular.column(i).cast(new_types[i]) for i in range(tabular.num_columns)]
+            cols = [
+                tabular.column(i).cast(new_types[i])
+                if not is_list(new_types[i])
+                else tabular.column(i)
+                for i in range(tabular.num_columns)
+            ]
             tabular = Table.from_arrays(cols, names=tabular.column_names)
         _validate_schema(tabular.schema.types)
     elif pd is not None and isinstance(tabular, pd.DataFrame):
diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py
index e0f16b06..9e138b1b 100644
--- a/bindings/python/test/test_arrow.py
+++ b/bindings/python/test/test_arrow.py
@@ -19,7 +19,7 @@
 import threading
 import unittest
 import unittest.mock as mock
-from datetime import date, datetime
+from datetime import date, datetime, timedelta
 from pathlib import Path
 from test import client_context
 from test.utils import AllowListEventListener, NullsTestMixin
@@ -27,6 +27,7 @@
 import pyarrow as pa
 import pyarrow.json
 import pymongo
+import pytest
 from bson import Binary, Code, CodecOptions, Decimal128, ObjectId, json_util
 from pyarrow import (
     Table,
@@ -63,6 +64,11 @@
     ObjectIdType,
 )
 
+try:
+    import pandas as pd
+except ImportError:
+    pd = None
+
 HERE = Path(__file__).absolute().parent
 
 
@@ -1082,6 +1088,122 @@ def test_decimal128(self):
         coll_data = list(self.coll.find({}))
         assert coll_data[0]["data"] == Decimal128(a)
 
+    def alltypes_sample(self, size=10000, seed=0, categorical=False):
+        # modified from https://github.com/apache/arrow/blob/main/python/pyarrow/tests/parquet/common.py#L139
+        import numpy as np
+        import pandas as pd
+
+        np.random.seed(seed)
+        arrays = {
+            "uint8": np.arange(size, dtype=np.uint8),
+            "uint16": np.arange(size, dtype=np.uint16),
+            "uint32": np.arange(size, dtype=np.uint32),
+            "uint64": np.arange(size, dtype=np.uint64),
+            "int8": np.arange(size, dtype=np.int8),
+            "int16": np.arange(size, dtype=np.int16),
+            "int32": np.arange(size, dtype=np.int32),
+            "int64": np.arange(size, dtype=np.int64),
+            "float16": np.arange(size, dtype=np.float16),
+            "float32": np.arange(size, dtype=np.float32),
+            "float64": np.arange(size, dtype=np.float64),
+            "bool": np.random.randn(size) > 0,
+            "datetime_ms": np.arange("2016-01-01T00:00:00.001", size, dtype="datetime64[ms]"),
+            "datetime_us": np.arange("2016-01-01T00:00:00.000001", size, dtype="datetime64[us]"),
+            "datetime_ns": np.arange("2016-01-01T00:00:00.000000001", size, dtype="datetime64[ns]"),
+            "timedelta": np.arange(size, dtype="timedelta64[s]"),
+            "str": pd.Series([str(x) for x in range(size)]),
+            "empty_str": [""] * size,
+            "str_with_nulls": [None] + [str(x) for x in range(size - 2)] + [None],
+            "null": [None] * size,
+            "null_list": [None] * 2 + [[None] * (x % 4) for x in range(size - 2)],
+        }
+        if categorical:
+            arrays["str_category"] = arrays["str"].astype("category")
+        return pd.DataFrame(arrays)
+
+    def convert_categorical_columns_to_string(self, table):
+        """
+        Converts any categorical columns in an Arrow Table into string columns.
+        This preprocessing step ensures compatibility with PyMongoArrow schema validation.
+        """
+        new_columns = []
+        for column in table.columns:
+            if pa.types.is_dictionary(column.type):
+                # Convert dictionary (categorical) columns to string
+                new_columns.append(pa.array(column.combine_chunks().to_pandas(), type=pa.string()))
+            else:
+                # Keep other column types intact
+                new_columns.append(column)
+        # Return a new Arrow Table
+        return pa.Table.from_arrays(new_columns, names=table.column_names)
+
+    def compare_arrow_mongodb_data(self, arrow_table, mongo_data):
+        """
+        Compare data types and precision between an Arrow Table and MongoDB documents.
+
+        Params:
+            arrow_table (pyarrow.Table): The original Arrow Table before insertion.
+            mongo_data (list): The list of MongoDB documents fetched from the collection.
+
+        Raises:
+            AssertionError: If any data type or value doesn't match between Arrow and MongoDB.
+        """
+        import decimal
+
+        import numpy as np
+
+        assert len(mongo_data) == len(arrow_table), "MongoDB data length mismatch with Arrow Table."
+
+        # Convert Arrow Table to Python dict format for comparison
+        arrow_dict = arrow_table.to_pydict()
+
+        for row_idx in range(arrow_table.num_rows):
+            mongo_document = mongo_data[row_idx]  # Fetch the corresponding MongoDB document
+
+            for column_name in arrow_table.column_names:
+                arrow_value = arrow_dict[column_name][
+                    row_idx
+                ]  # Get the value from the Arrow Table (Python representation)
+                mongo_value = mongo_document.get(column_name)  # Get the value from MongoDB
+
+                if isinstance(arrow_value, decimal.Decimal):
+                    assert (
+                        Decimal128(arrow_value).to_decimal() == mongo_value.to_decimal()
+                    ), f"Precision loss in decimal field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
+                elif isinstance(arrow_value, (np.datetime64, pd.Timestamp, datetime)):
+                    arrow_value_rounded = pd.Timestamp(arrow_value).round(
+                        "ms"
+                    )  # Round to milliseconds
+                    assert (
+                        arrow_value_rounded.to_pydatetime() == mongo_value
+                    ), f"Datetime mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value_rounded}, got {mongo_value}."
+                elif isinstance(arrow_value, (list, np.ndarray)):
+                    assert (
+                        arrow_value == mongo_value
+                    ), f"List mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
+                elif isinstance(arrow_value, timedelta):
+                    assert (
+                        arrow_value == timedelta(seconds=mongo_value)
+                    ), f"Timedelta mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
+                else:
+                    assert (
+                        arrow_value == mongo_value
+                    ), f"Value mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
+
+    def test_all_types(self):
+        """
+        Test the conversion of all standard data types from pandas → PyArrow → BSON.
+        """
+        if pd is None:
+            pytest.skip("Requires pandas.")
+        df = self.alltypes_sample(size=100, seed=42, categorical=True)
+        arrow_table = pa.Table.from_pandas(df)
+        arrow_table = self.convert_categorical_columns_to_string(arrow_table)
+        self.coll.drop()
+        write(self.coll, arrow_table)
+        coll_data = list(self.coll.find({}))
+        self.compare_arrow_mongodb_data(arrow_table, coll_data)
+
     def test_empty_embedded_array(self):
        # From INTPYTHON-575.
         self.coll.drop()
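
Not part of the patch: a minimal usage sketch of the new `auto_convert` flag. It assumes a local `mongod` on the default port; the `test.auto_convert_demo` collection name is made up for the demo, and the exact exception raised when validation rejects an unconverted type depends on `_validate_schema` in the installed pymongoarrow.

import pyarrow as pa
from pymongo import MongoClient

from pymongoarrow.api import write

client = MongoClient()  # assumes a local mongod on the default port
coll = client.test.auto_convert_demo  # hypothetical scratch collection
coll.drop()

# uint8 and duration columns have no direct BSON equivalent; with
# auto_convert=True (the default) write() upcasts them to int32/int64 first.
table = pa.table(
    {
        "small": pa.array([1, 2, 3], type=pa.uint8()),
        "elapsed": pa.array([10, 20, 30], type=pa.duration("s")),
    }
)
write(coll, table)  # succeeds: columns are converted before encoding
print(coll.find_one({}, {"_id": 0}))  # e.g. {'small': 1, 'elapsed': 10}

# With auto_convert=False the types reach schema validation unchanged and
# are rejected (a ValueError in the code path shown in the diff above).
try:
    write(coll, table, auto_convert=False)
except ValueError as exc:
    print("rejected:", exc)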