Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/polars/materialization/my_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@
file="./df.json",
combine=df_builder,
),
# materialize the dataframe to an ndjson file
to.ndjson(
dependencies=output_columns,
id="df_to_ndjson",
file="./df.ndjson",
combine=df_builder,
),
to.avro(
dependencies=output_columns,
id="df_to_avro",
Expand Down Expand Up @@ -117,6 +124,7 @@
"df_to_parquet_build_result",
"df_to_feather_build_result",
"df_to_json_build_result",
"df_to_ndjson_build_result",
"df_to_avro_build_result",
"df_to_spreadsheet_build_result",
"df_to_database_build_result",
Expand All @@ -127,6 +135,7 @@
print(additional_outputs["df_to_parquet_build_result"])
print(additional_outputs["df_to_feather_build_result"])
print(additional_outputs["df_to_json_build_result"])
print(additional_outputs["df_to_ndjson_build_result"])
print(additional_outputs["df_to_avro_build_result"])
print(additional_outputs["df_to_spreadsheet_build_result"])
print(additional_outputs["df_to_database_build_result"])
60 changes: 60 additions & 0 deletions hamilton/plugins/polars_post_1_0_0_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,64 @@ def name(cls) -> str:
return "json"


@dataclasses.dataclass
class PolarsNDJSONReader(DataLoader):
"""
Class specifically to handle loading NDJSON (newline-delimited JSON) files with Polars.
Should map to https://docs.pola.rs/api/python/stable/reference/api/polars.read_ndjson.html
"""

source: Union[str, Path, IOBase, bytes]
schema: SchemaDefinition = None
schema_overrides: SchemaDefinition = None

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE]

def _get_loading_kwargs(self):
kwargs = {}
if self.schema is not None:
kwargs["schema"] = self.schema
if self.schema_overrides is not None:
kwargs["schema_overrides"] = self.schema_overrides
return kwargs

def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
df = pl.read_ndjson(self.source, **self._get_loading_kwargs())
metadata = utils.get_file_metadata(self.source)
return df, metadata

@classmethod
def name(cls) -> str:
return "ndjson"


@dataclasses.dataclass
class PolarsNDJSONWriter(DataSaver):
"""
Class specifically to handle saving NDJSON (newline-delimited JSON) files with Polars.
Should map to https://docs.pola.rs/api/python/stable/reference/api/polars.DataFrame.write_ndjson.html
"""

file: Union[IOBase, str, Path]

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE, pl.LazyFrame]

def save_data(self, data: Union[DATAFRAME_TYPE, pl.LazyFrame]) -> Dict[str, Any]:
if isinstance(data, pl.LazyFrame):
data = data.collect()

data.write_ndjson(self.file)
return utils.get_file_and_dataframe_metadata(self.file, data)

@classmethod
def name(cls) -> str:
return "ndjson"


@dataclasses.dataclass
class PolarsSpreadsheetReader(DataLoader):
"""
Expand Down Expand Up @@ -822,6 +880,8 @@ def register_data_loaders():
PolarsAvroWriter,
PolarsJSONReader,
PolarsJSONWriter,
PolarsNDJSONReader,
PolarsNDJSONWriter,
PolarsDatabaseReader,
PolarsDatabaseWriter,
PolarsSpreadsheetReader,
Expand Down
18 changes: 18 additions & 0 deletions tests/plugins/test_polars_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
PolarsFeatherWriter,
PolarsJSONReader,
PolarsJSONWriter,
PolarsNDJSONReader,
PolarsNDJSONWriter,
PolarsParquetReader,
PolarsParquetWriter,
PolarsSpreadsheetReader,
Expand Down Expand Up @@ -129,6 +131,22 @@ def test_polars_json(df: pl.DataFrame, tmp_path: pathlib.Path) -> None:
polars.testing.assert_frame_equal(df, df2)


def test_polars_ndjson(df: pl.DataFrame, tmp_path: pathlib.Path) -> None:
file = tmp_path / "test.ndjson"
writer = PolarsNDJSONWriter(file=file)
writer.save_data(df)

reader = PolarsNDJSONReader(source=file)
kwargs2 = reader._get_loading_kwargs()
df2, metadata = reader.load_data(pl.DataFrame)

assert PolarsNDJSONWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame]
assert PolarsNDJSONReader.applicable_types() == [pl.DataFrame]
assert df2.shape == (2, 2)
assert "schema" not in kwargs2
polars.testing.assert_frame_equal(df, df2)


def test_polars_avro(df: pl.DataFrame, tmp_path: pathlib.Path) -> None:
file = tmp_path / "test.avro"

Expand Down
18 changes: 18 additions & 0 deletions tests/plugins/test_polars_lazyframe_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
PolarsFeatherWriter,
PolarsJSONReader,
PolarsJSONWriter,
PolarsNDJSONReader,
PolarsNDJSONWriter,
PolarsParquetWriter,
PolarsSpreadsheetReader,
PolarsSpreadsheetWriter,
Expand Down Expand Up @@ -145,6 +147,22 @@ def test_polars_json(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None:
assert_frame_equal(df.collect(), df2)


def test_polars_ndjson(df: pl.LazyFrame, tmp_path: pathlib.Path) -> None:
file = tmp_path / "test.ndjson"
writer = PolarsNDJSONWriter(file=file)
writer.save_data(df)

reader = PolarsNDJSONReader(source=file)
kwargs2 = reader._get_loading_kwargs()
df2, metadata = reader.load_data(pl.DataFrame)

assert PolarsNDJSONWriter.applicable_types() == [pl.DataFrame, pl.LazyFrame]
assert PolarsNDJSONReader.applicable_types() == [pl.DataFrame]
assert df2.shape == (2, 2)
assert "schema" not in kwargs2
assert_frame_equal(df.collect(), df2)


@pytest.mark.skipif(
sys.version_info.major == 3 and sys.version_info.minor == 12,
reason="weird connectorx error on 3.12",
Expand Down
Loading