|
19 | 19 | import threading
|
20 | 20 | import unittest
|
21 | 21 | import unittest.mock as mock
|
22 |
| -from datetime import date, datetime |
| 22 | +from datetime import date, datetime, timedelta |
23 | 23 | from pathlib import Path
|
24 | 24 | from test import client_context
|
25 | 25 | from test.utils import AllowListEventListener, NullsTestMixin
|
26 | 26 |
|
27 | 27 | import pyarrow as pa
|
28 | 28 | import pyarrow.json
|
29 | 29 | import pymongo
|
| 30 | +import pytest |
30 | 31 | from bson import Binary, Code, CodecOptions, Decimal128, ObjectId, json_util
|
31 | 32 | from pyarrow import (
|
32 | 33 | Table,
|
|
63 | 64 | ObjectIdType,
|
64 | 65 | )
|
65 | 66 |
|
| 67 | +try: |
| 68 | + import pandas as pd |
| 69 | +except ImportError: |
| 70 | + pd = None |
| 71 | + |
66 | 72 | HERE = Path(__file__).absolute().parent
|
67 | 73 |
|
68 | 74 |
|
@@ -1082,6 +1088,122 @@ def test_decimal128(self):
|
1082 | 1088 | coll_data = list(self.coll.find({}))
|
1083 | 1089 | assert coll_data[0]["data"] == Decimal128(a)
|
1084 | 1090 |
|
    def alltypes_sample(self, size=10000, seed=0, categorical=False):
        """Build a pandas DataFrame with one column per sampled dtype.

        Serves as a fixture for DataFrame -> Arrow -> MongoDB round-trip tests.

        Params:
            size (int): number of rows to generate.
            seed (int): numpy RNG seed; makes the random "bool" column reproducible.
            categorical (bool): when True, add a "str_category" column that is
                the "str" column converted to a pandas categorical dtype.

        Returns:
            pandas.DataFrame
        """
        # modified from https://github.com/apache/arrow/blob/main/python/pyarrow/tests/parquet/common.py#L139
        import numpy as np
        import pandas as pd

        np.random.seed(seed)
        arrays = {
            "uint8": np.arange(size, dtype=np.uint8),
            "uint16": np.arange(size, dtype=np.uint16),
            "uint32": np.arange(size, dtype=np.uint32),
            "uint64": np.arange(size, dtype=np.uint64),
            "int8": np.arange(size, dtype=np.int8),
            "int16": np.arange(size, dtype=np.int16),
            "int32": np.arange(size, dtype=np.int32),
            "int64": np.arange(size, dtype=np.int64),
            "float16": np.arange(size, dtype=np.float16),
            "float32": np.arange(size, dtype=np.float32),
            "float64": np.arange(size, dtype=np.float64),
            # Random booleans; deterministic because of np.random.seed above.
            "bool": np.random.randn(size) > 0,
            # Timestamps at ms/us/ns resolution starting just past 2016-01-01.
            # NOTE(review): integer stop with a datetime64 start mirrors the
            # upstream Arrow fixture; confirm it yields `size` rows on the
            # numpy version in use.
            "datetime_ms": np.arange("2016-01-01T00:00:00.001", size, dtype="datetime64[ms]"),
            "datetime_us": np.arange("2016-01-01T00:00:00.000001", size, dtype="datetime64[us]"),
            "datetime_ns": np.arange("2016-01-01T00:00:00.000000001", size, dtype="datetime64[ns]"),
            "timedelta": np.arange(size, dtype="timedelta64[s]"),
            "str": pd.Series([str(x) for x in range(size)]),
            "empty_str": [""] * size,
            # Nulls at both ends, stringified ints in between.
            "str_with_nulls": [None] + [str(x) for x in range(size - 2)] + [None],
            "null": [None] * size,
            # Two leading nulls, then lists of 0-3 nulls cycling by index.
            "null_list": [None] * 2 + [[None] * (x % 4) for x in range(size - 2)],
        }
        if categorical:
            arrays["str_category"] = arrays["str"].astype("category")
        return pd.DataFrame(arrays)
| 1123 | + |
| 1124 | + def convert_categorical_columns_to_string(self, table): |
| 1125 | + """ |
| 1126 | + Converts any categorical columns in an Arrow Table into string columns. |
| 1127 | + This preprocessing step ensures compatibility with PyMongoArrow schema validation. |
| 1128 | + """ |
| 1129 | + new_columns = [] |
| 1130 | + for column_name, column in zip(table.column_names, table.columns): |
| 1131 | + if pa.types.is_dictionary(column.type): |
| 1132 | + # Convert dictionary (categorical) columns to string |
| 1133 | + new_columns.append(pa.array(column.combine_chunks().to_pandas(), type=pa.string())) |
| 1134 | + else: |
| 1135 | + # Keep other column types intact |
| 1136 | + new_columns.append(column) |
| 1137 | + # Return a new Arrow Table |
| 1138 | + return pa.Table.from_arrays(new_columns, names=table.column_names) |
| 1139 | + |
    def compare_arrow_mongodb_data(self, arrow_table, mongo_data):
        """
        Compare data types and precision between an Arrow Table and MongoDB documents.

        Values are compared row by row and column by column, with type-specific
        handling: decimals via Decimal128, datetimes rounded to milliseconds
        (MongoDB datetime resolution), timedeltas against integer seconds, and
        everything else by plain equality.

        Params:
            arrow_table (pyarrow.Table): The original Arrow Table before insertion.
            mongo_data (list): The list of MongoDB documents fetched from the collection.

        Raises:
            AssertionError: If any data type or value doesn't match between Arrow and MongoDB.
        """
        import decimal

        import numpy as np

        assert len(mongo_data) == len(arrow_table), "MongoDB data length mismatch with Arrow Table."

        # Convert Arrow Table to Python dict format for comparison
        arrow_dict = arrow_table.to_pydict()

        for row_idx in range(arrow_table.num_rows):
            mongo_document = mongo_data[row_idx]  # Fetch the corresponding MongoDB document

            for column_name in arrow_table.column_names:
                arrow_value = arrow_dict[column_name][
                    row_idx
                ]  # Get the value from the Arrow Table (Python representation)
                mongo_value = mongo_document.get(column_name)  # Get the value from MongoDB

                if isinstance(arrow_value, decimal.Decimal):
                    # Round-trip both sides through Decimal128 so the comparison
                    # happens at BSON decimal precision.
                    assert (
                        Decimal128(arrow_value).to_decimal() == Decimal128(mongo_value).to_decimal()
                    ), f"Precision loss in decimal field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
                elif isinstance(arrow_value, (np.datetime64, pd.Timestamp, datetime)):
                    # NOTE(review): this branch dereferences the module-level
                    # `pd`, which is None when pandas is missing — callers must
                    # only reach this helper when pandas is installed.
                    arrow_value_rounded = pd.Timestamp(arrow_value).round(
                        "ms"
                    )  # Round to milliseconds
                    assert (
                        arrow_value_rounded.to_pydatetime() == mongo_value
                    ), f"Datetime mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value_rounded}, got {mongo_value}."
                elif isinstance(arrow_value, (list, np.ndarray)):
                    assert (
                        arrow_value == mongo_value
                    ), f"List mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
                elif isinstance(arrow_value, timedelta):
                    # Durations are assumed to be stored in MongoDB as integer
                    # seconds — TODO confirm against the writer's encoding.
                    assert (
                        arrow_value == timedelta(seconds=mongo_value)
                    ), f"Timedelta mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
                else:
                    assert (
                        arrow_value == mongo_value
                    ), f"Value mismatch in field '{column_name}' for row {row_idx}. Expected {arrow_value}, got {mongo_value}."
| 1192 | + |
| 1193 | + def test_all_types(self): |
| 1194 | + """ |
| 1195 | + Test the conversion of all standard data types from Parquet → PyArrow → Python → BSON. |
| 1196 | + """ |
| 1197 | + if pd is None: |
| 1198 | + pytest.skip("Requires pandas.", allow_module_level=True) |
| 1199 | + df = self.alltypes_sample(size=100, seed=42, categorical=True) |
| 1200 | + arrow_table = pa.Table.from_pandas(df) |
| 1201 | + arrow_table = self.convert_categorical_columns_to_string(arrow_table) |
| 1202 | + self.coll.drop() |
| 1203 | + write(self.coll, arrow_table) |
| 1204 | + coll_data = list(self.coll.find({})) |
| 1205 | + self.compare_arrow_mongodb_data(arrow_table, coll_data) |
| 1206 | + |
1085 | 1207 | def test_empty_embedded_array(self):
|
1086 | 1208 | # From INTPYTHON-575.
|
1087 | 1209 | self.coll.drop()
|
|
0 commit comments