# Patch summary (leading text before the first "diff --git" header; patch tools skip it).
#
# Adds NDJSON (newline-delimited JSON) load/save support to the Polars plugin:
#   * hamilton/plugins/polars_post_1_0_0_extensions.py
#       - PolarsNDJSONReader (DataLoader): calls pl.read_ndjson(source, ...) and forwards
#         `schema` / `schema_overrides` only when they are not None; applicable type is
#         DATAFRAME_TYPE; name() returns "ndjson".
#       - PolarsNDJSONWriter (DataSaver): accepts DATAFRAME_TYPE or pl.LazyFrame, collects a
#         LazyFrame before calling write_ndjson(file); name() returns "ndjson".
#       - Both classes are added to the adapter sequence shown in the
#         register_data_loaders() hunk.
#   * examples/polars/materialization/my_script.py
#       - Adds a to.ndjson(...) materializer (id "df_to_ndjson", file "./df.ndjson") and
#         requests/prints "df_to_ndjson_build_result".
#   * tests/plugins/test_polars_extensions.py and test_polars_lazyframe_extensions.py
#       - Add test_polars_ndjson: write with the writer, read back with the reader, and
#         compare frames; also assert applicable_types() and that no "schema" kwarg is
#         produced by default.
#
# NOTE(review): `schema` and `schema_overrides` are annotated `SchemaDefinition` but default
#   to None -- an Optional-style annotation may be intended; confirm against sibling readers.
# NOTE(review): load_data passes `self.source` to utils.get_file_metadata although the
#   annotation also allows bytes / IOBase -- confirm the helper handles those inputs.
# NOTE(review): both new tests bind `metadata` from load_data but never use it.
# NOTE(review): in this copy the hunks appear collapsed onto a few long lines; presumably the
#   line breaks must be restored before the patch can be applied -- TODO confirm.
diff --git a/examples/polars/materialization/my_script.py b/examples/polars/materialization/my_script.py index 01593dd7c..9c8836a0f 100644 --- a/examples/polars/materialization/my_script.py +++ b/examples/polars/materialization/my_script.py @@ -78,6 +78,13 @@ file="./df.json", combine=df_builder, ), + # materialize the dataframe to an ndjson file + to.ndjson( + dependencies=output_columns, + id="df_to_ndjson", + file="./df.ndjson", + combine=df_builder, + ), to.avro( dependencies=output_columns, id="df_to_avro", @@ -117,6 +124,7 @@ "df_to_parquet_build_result", "df_to_feather_build_result", "df_to_json_build_result", + "df_to_ndjson_build_result", "df_to_avro_build_result", "df_to_spreadsheet_build_result", "df_to_database_build_result", @@ -127,6 +135,7 @@ print(additional_outputs["df_to_parquet_build_result"]) print(additional_outputs["df_to_feather_build_result"]) print(additional_outputs["df_to_json_build_result"]) +print(additional_outputs["df_to_ndjson_build_result"]) print(additional_outputs["df_to_avro_build_result"]) print(additional_outputs["df_to_spreadsheet_build_result"]) print(additional_outputs["df_to_database_build_result"]) diff --git a/hamilton/plugins/polars_post_1_0_0_extensions.py b/hamilton/plugins/polars_post_1_0_0_extensions.py index 056a87dce..25c0f6826 100644 --- a/hamilton/plugins/polars_post_1_0_0_extensions.py +++ b/hamilton/plugins/polars_post_1_0_0_extensions.py @@ -564,6 +564,64 @@ def name(cls) -> str: return "json" +@dataclasses.dataclass +class PolarsNDJSONReader(DataLoader): + """ + Class specifically to handle loading NDJSON (newline-delimited JSON) files with Polars. 
+ Should map to https://docs.pola.rs/api/python/stable/reference/api/polars.read_ndjson.html + """ + + source: Union[str, Path, IOBase, bytes] + schema: SchemaDefinition = None + schema_overrides: SchemaDefinition = None + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE] + + def _get_loading_kwargs(self): + kwargs = {} + if self.schema is not None: + kwargs["schema"] = self.schema + if self.schema_overrides is not None: + kwargs["schema_overrides"] = self.schema_overrides + return kwargs + + def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]: + df = pl.read_ndjson(self.source, **self._get_loading_kwargs()) + metadata = utils.get_file_metadata(self.source) + return df, metadata + + @classmethod + def name(cls) -> str: + return "ndjson" + + +@dataclasses.dataclass +class PolarsNDJSONWriter(DataSaver): + """ + Class specifically to handle saving NDJSON (newline-delimited JSON) files with Polars. + Should map to https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_ndjson.html + """ + + file: Union[IOBase, str, Path] + + @classmethod + def applicable_types(cls) -> Collection[Type]: + return [DATAFRAME_TYPE, pl.LazyFrame] + + def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]: + if isinstance(data, pl.LazyFrame): + data = data.collect() + + data.write_ndjson(self.file) + return utils.get_file_and_dataframe_metadata(self.file, data) + + @classmethod + def name(cls) -> str: + return "ndjson" + + @dataclasses.dataclass class PolarsSpreadsheetReader(DataLoader): """ @@ -822,6 +880,8 @@ def register_data_loaders(): PolarsAvroWriter, PolarsJSONReader, PolarsJSONWriter, + PolarsNDJSONReader, + PolarsNDJSONWriter, PolarsDatabaseReader, PolarsDatabaseWriter, PolarsSpreadsheetReader, diff --git a/tests/plugins/test_polars_extensions.py b/tests/plugins/test_polars_extensions.py index 796e413ce..200373075 100644 --- a/tests/plugins/test_polars_extensions.py +++ 
b/tests/plugins/test_polars_extensions.py @@ -37,6 +37,8 @@ PolarsFeatherWriter, PolarsJSONReader, PolarsJSONWriter, + PolarsNDJSONReader, + PolarsNDJSONWriter, PolarsParquetReader, PolarsParquetWriter, PolarsSpreadsheetReader, @@ -129,6 +131,22 @@ def test_polars_json(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: polars.testing.assert_frame_equal(df, df2) +def test_polars_ndjson(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.ndjson" + writer = PolarsNDJSONWriter(file=file) + writer.save_data(df) + + reader = PolarsNDJSONReader(source=file) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.DataFrame) + + assert PolarsNDJSONWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame] + assert PolarsNDJSONReader.applicable_types() == [pl.DataFrame] + assert df2.shape == (2, 2) + assert "schema" not in kwargs2 + polars.testing.assert_frame_equal(df, df2) + + def test_polars_avro(df: pl.DataFrame, tmp_path: pathlib.Path) -> None: file = tmp_path / "test.avro" diff --git a/tests/plugins/test_polars_lazyframe_extensions.py b/tests/plugins/test_polars_lazyframe_extensions.py index 71715e6d8..a35929654 100644 --- a/tests/plugins/test_polars_lazyframe_extensions.py +++ b/tests/plugins/test_polars_lazyframe_extensions.py @@ -37,6 +37,8 @@ PolarsFeatherWriter, PolarsJSONReader, PolarsJSONWriter, + PolarsNDJSONReader, + PolarsNDJSONWriter, PolarsParquetWriter, PolarsSpreadsheetReader, PolarsSpreadsheetWriter, @@ -145,6 +147,22 @@ def test_polars_json(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: assert_frame_equal(df.collect(), df2) +def test_polars_ndjson(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None: + file = tmp_path / "test.ndjson" + writer = PolarsNDJSONWriter(file=file) + writer.save_data(df) + + reader = PolarsNDJSONReader(source=file) + kwargs2 = reader._get_loading_kwargs() + df2, metadata = reader.load_data(pl.DataFrame) + + assert PolarsNDJSONWriter.applicable_types() == [pl.DataFrame, 
pl.LazyFrame] + assert PolarsNDJSONReader.applicable_types() == [pl.DataFrame] + assert df2.shape == (2, 2) + assert "schema" not in kwargs2 + assert_frame_equal(df.collect(), df2) + + @pytest.mark.skipif( sys.version_info.major == 3 and sys.version_info.minor == 12, reason="weird connectorx error on 3.12",