From 70737f15e004787284b1622d307c6a4c54815c38 Mon Sep 17 00:00:00 2001 From: Liran Bareket Date: Tue, 8 Oct 2024 13:41:30 -0400 Subject: [PATCH 1/4] Split CSV files. Added Integration Test. --- src/databricks/labs/blueprint/installation.py | 28 +++++++++++++++++++ tests/integration/test_installation.py | 15 +++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/databricks/labs/blueprint/installation.py b/src/databricks/labs/blueprint/installation.py index 71db98f..1b9d1bf 100644 --- a/src/databricks/labs/blueprint/installation.py +++ b/src/databricks/labs/blueprint/installation.py @@ -40,6 +40,8 @@ __all__ = ["Installation", "MockInstallation", "IllegalState", "NotInstalled", "SerdeError"] +FILE_SIZE_LIMIT: int = 1024 * 1024 * 10 + class IllegalState(ValueError): pass @@ -358,8 +360,34 @@ def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type): raise KeyError(f"Unknown extension: {extension}") logger.debug(f"Converting {type_ref.__name__} into {extension.upper()} format") raw = converters[extension](as_dict, type_ref) + if extension == "csv": + split = self._split_content(raw) + if len(split) > 1: + for i, chunk in enumerate(split): + self.upload(f"{filename[0:-4]}.{i + 1}.csv", chunk) + return + + # Check if the file is more than 10MB + if len(raw) > FILE_SIZE_LIMIT: + raise ValueError(f"File size too large: {len(raw)} bytes") + self.upload(filename, raw) + @staticmethod + def _split_content(raw: bytes) -> list[bytes]: + """The `_split_content` method is a private method that is used to split the raw bytes of a file into chunks + that are less than 10MB in size. This method is called by the `_overwrite_content` method.""" + chunks = [] + chunk = b"" + lines = raw.split(b"\n") + for line in lines: + if len(chunk) + len(line) > FILE_SIZE_LIMIT: + chunks.append(chunk) + chunk = lines[0] + b"\n" + chunk += line + b"\n" + chunks.append(chunk) + return chunks + @staticmethod def _global_installation(product): """The `_global_installation` method is a private method that is used to determine the installation folder diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index 5d9bede..bdf1bb9 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -3,7 +3,7 @@ import pytest from databricks.sdk.errors import PermissionDenied from databricks.sdk.service.provisioning import Workspace - +from databricks.sdk.service.catalog import TableInfo from databricks.labs.blueprint.installation import Installation @@ -73,6 +73,19 @@ def test_saving_list_of_dataclasses_to_csv(new_installation): assert len(loaded) == 2 +def test_saving_list_of_dataclasses_to_multiple_csvs(new_installation): + tables: list[TableInfo] = [] + for i in range(500000): + tables.append(TableInfo(name=f"long_table_name_{i}", schema_name="very_long_schema_name")) + new_installation.save( + tables, + filename="many_tables_test.csv", + ) + + loaded = new_installation.load(list[Workspace], filename="many_tables_test.1.csv") + assert len(loaded) > 100 + + @pytest.mark.parametrize( "ext,magic", [ From c011003f5f91ffdd5599299d2395a044a67bf89a Mon Sep 17 00:00:00 2001 From: Liran Bareket Date: Wed, 9 Oct 2024 10:49:23 -0400 Subject: [PATCH 2/4] Added more efficient split --- src/databricks/labs/blueprint/installation.py | 44 ++++++++++++------- tests/integration/test_installation.py | 2 +- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/src/databricks/labs/blueprint/installation.py b/src/databricks/labs/blueprint/installation.py index 1b9d1bf..5ef6901 100644 --- a/src/databricks/labs/blueprint/installation.py +++ b/src/databricks/labs/blueprint/installation.py @@ -350,7 +350,7 @@ def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type): The `as_dict` argument is the dictionary representation of the object that is to be written to the file. The `type_ref` argument is the type of the object that is being saved.""" - converters: dict[str, Callable[[Any, type], bytes]] = { + converters: dict[str, Callable[[Any, type], list[bytes]]] = { "json": self._dump_json, "yml": self._dump_yaml, "csv": self._dump_csv, @@ -359,19 +359,16 @@ def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type): if extension not in converters: raise KeyError(f"Unknown extension: {extension}") logger.debug(f"Converting {type_ref.__name__} into {extension.upper()} format") - raw = converters[extension](as_dict, type_ref) - if extension == "csv": - split = self._split_content(raw) - if len(split) > 1: - for i, chunk in enumerate(split): - self.upload(f"{filename[0:-4]}.{i + 1}.csv", chunk) - return - + raws = converters[extension](as_dict, type_ref) + if len(raws) > 1: + for i, raw in enumerate(raws): + self.upload(f"{filename[0:-4]}.{i + 1}.csv", raw) + return # Check if the file is more than 10MB - if len(raw) > FILE_SIZE_LIMIT: + if len(raws[0]) > FILE_SIZE_LIMIT: raise ValueError(f"File size too large: {len(raw)} bytes") - self.upload(filename, raw) + self.upload(filename, raws[0]) @staticmethod def _split_content(raw: bytes) -> list[bytes]: @@ -775,19 +772,19 @@ def _explain_why(type_ref: type, path: list[str], raw: Any) -> str: return f'{".".join(path)}: not a {type_ref.__name__}: {raw}' @staticmethod - def _dump_json(as_dict: Json, _: type) -> bytes: + def _dump_json(as_dict: Json, _: type) -> list[bytes]: """The `_dump_json` method is a private method that is used to serialize a dictionary to a JSON string. This method is called by the `save` method.""" - return json.dumps(as_dict, indent=2).encode("utf8") + return [json.dumps(as_dict, indent=2).encode("utf8")] @staticmethod - def _dump_yaml(raw: Json, _: type) -> bytes: + def _dump_yaml(raw: Json, _: type) -> list[bytes]: """The `_dump_yaml` method is a private method that is used to serialize a dictionary to a YAML string. This method is called by the `save` method.""" try: from yaml import dump # pylint: disable=import-outside-toplevel - return dump(raw).encode("utf8") + return [dump(raw).encode("utf8")] except ImportError as err: raise SyntaxError("PyYAML is not installed. Fix: pip install databricks-labs-blueprint[yaml]") from err @@ -809,9 +806,10 @@ def _load_yaml(raw: BinaryIO) -> Json: raise SyntaxError("PyYAML is not installed. Fix: pip install databricks-labs-blueprint[yaml]") from err @staticmethod - def _dump_csv(raw: list[Json], type_ref: type) -> bytes: + def _dump_csv(raw: list[Json], type_ref: type) -> list[bytes]: """The `_dump_csv` method is a private method that is used to serialize a list of dictionaries to a CSV string. This method is called by the `save` method.""" + raws = [] type_args = get_args(type_ref) if not type_args: raise SerdeError(f"Writing CSV is only supported for lists. Got {type_ref}") @@ -832,9 +830,21 @@ def _dump_csv(raw: list[Json], type_ref: type) -> bytes: writer = csv.DictWriter(buffer, field_names, dialect="excel") writer.writeheader() for as_dict in raw: + # Check if the buffer + the current row is over the file size limit + before_pos = buffer.tell() writer.writerow(as_dict) + if buffer.tell() > FILE_SIZE_LIMIT: + buffer.seek(before_pos) + buffer.truncate() + raws.append(buffer.getvalue().encode("utf8")) + buffer = io.StringIO() + writer = csv.DictWriter(buffer, field_names, dialect="excel") + writer.writeheader() + writer.writerow(as_dict) + buffer.seek(0) - return buffer.read().encode("utf8") + raws.append(buffer.getvalue().encode("utf8")) + return raws @staticmethod def _load_csv(raw: BinaryIO) -> list[Json]: diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index bdf1bb9..074f42e 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -82,7 +82,7 @@ def test_saving_list_of_dataclasses_to_multiple_csvs(new_installation): filename="many_tables_test.csv", ) - loaded = new_installation.load(list[Workspace], filename="many_tables_test.1.csv") + loaded = new_installation.load(list[Workspace], filename="many_tables_test.csv") assert len(loaded) > 100 From d0820389e9e7ef48c1f5dcf30744786821336293 Mon Sep 17 00:00:00 2001 From: Liran Bareket Date: Wed, 9 Oct 2024 12:09:17 -0400 Subject: [PATCH 3/4] Added support for read. --- src/databricks/labs/blueprint/installation.py | 33 +++++++++++++++---- tests/integration/test_installation.py | 2 +- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/databricks/labs/blueprint/installation.py b/src/databricks/labs/blueprint/installation.py index 5ef6901..829379f 100644 --- a/src/databricks/labs/blueprint/installation.py +++ b/src/databricks/labs/blueprint/installation.py @@ -134,6 +134,10 @@ def check_folder(install_folder: str) -> Installation | None: tasks.append(functools.partial(check_folder, service_principal_folder)) return Threads.strict(f"finding {product} installations", tasks) + @staticmethod + def extension(filename): + return filename.split(".")[-1] + @classmethod def load_local(cls, type_ref: type[T], file: Path) -> T: """Loads a typed file from the local file system.""" @@ -355,7 +359,7 @@ def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type): "yml": self._dump_yaml, "csv": self._dump_csv, } - extension = filename.split(".")[-1] + extension = self.extension(filename) if extension not in converters: raise KeyError(f"Unknown extension: {extension}") logger.debug(f"Converting {type_ref.__name__} into {extension.upper()} format") @@ -402,17 +406,34 @@ def _unmarshal_type(cls, as_dict, filename, type_ref): as_dict = cls._migrate_file_format(type_ref, expected_version, as_dict, filename) return cls._unmarshal(as_dict, [], type_ref) - def _load_content(self, filename: str) -> Json: + def _load_content(self, filename: str) -> Json | list[Json]: """The `_load_content` method is a private method that is used to load the contents of a file from WorkspaceFS as a dictionary. This method is called by the `load` method.""" with self._lock: # TODO: check how to make this fail fast during unit testing, otherwise # this currently hangs with the real installation class and mocked workspace client - with self._ws.workspace.download(f"{self.install_folder()}/{filename}") as f: - return self._convert_content(filename, f) + try: + with self._ws.workspace.download(f"{self.install_folder()}/{filename}") as f: + return self._convert_content(filename, f) + except NotFound: + # If the file is not found, check if it is a multi-part csv file + if self.extension(filename) != "csv": + raise + current_part = 1 + content = [] + try: + while True: + with self._ws.workspace.download(f"{self.install_folder()}/{filename[0:-4]}.{current_part}.csv") as f: + content += self._convert_content(filename, f) + current_part += 1 + except NotFound: + if current_part == 1: + raise + return content + @classmethod - def _convert_content(cls, filename: str, raw: BinaryIO) -> Json: + def _convert_content(cls, filename: str, raw: BinaryIO) -> Json|list[Json]: """The `_convert_content` method is a private method that is used to convert the raw bytes of a file to a dictionary. This method is called by the `_load_content` method.""" converters: dict[str, Callable[[BinaryIO], Any]] = { @@ -420,7 +441,7 @@ def _convert_content(cls, filename: str, raw: BinaryIO) -> Json: "yml": cls._load_yaml, "csv": cls._load_csv, } - extension = filename.split(".")[-1] + extension = cls.extension(filename) if extension not in converters: raise KeyError(f"Unknown extension: {extension}") try: diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index 074f42e..e8ba2ad 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -83,7 +83,7 @@ def test_saving_list_of_dataclasses_to_multiple_csvs(new_installation): ) loaded = new_installation.load(list[Workspace], filename="many_tables_test.csv") - assert len(loaded) > 100 + assert len(loaded) == 500000 @pytest.mark.parametrize( From cc9d62930296dff12a3a62945c4d06a09f316c3d Mon Sep 17 00:00:00 2001 From: Liran Bareket Date: Wed, 9 Oct 2024 12:49:21 -0400 Subject: [PATCH 4/4] Cleaned up code --- src/databricks/labs/blueprint/installation.py | 34 +++++++------------ tests/integration/test_installation.py | 3 +- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/databricks/labs/blueprint/installation.py b/src/databricks/labs/blueprint/installation.py index 829379f..5a05bf9 100644 --- a/src/databricks/labs/blueprint/installation.py +++ b/src/databricks/labs/blueprint/installation.py @@ -370,25 +370,10 @@ def _overwrite_content(self, filename: str, as_dict: Json, type_ref: type): return # Check if the file is more than 10MB if len(raws[0]) > FILE_SIZE_LIMIT: - raise ValueError(f"File size too large: {len(raw)} bytes") + raise ValueError(f"File size too large: {len(raws[0])} bytes") self.upload(filename, raws[0]) - @staticmethod - def _split_content(raw: bytes) -> list[bytes]: - """The `_split_content` method is a private method that is used to split the raw bytes of a file into chunks - that are less than 10MB in size. This method is called by the `_overwrite_content` method.""" - chunks = [] - chunk = b"" - lines = raw.split(b"\n") - for line in lines: - if len(chunk) + len(line) > FILE_SIZE_LIMIT: - chunks.append(chunk) - chunk = lines[0] + b"\n" - chunk += line + b"\n" - chunks.append(chunk) - return chunks - @staticmethod def _global_installation(product): """The `_global_installation` method is a private method that is used to determine the installation folder @@ -420,20 +405,25 @@ def _load_content(self, filename: str) -> Json | list[Json]: if self.extension(filename) != "csv": raise current_part = 1 - content = [] + content: list[Json] = [] try: while True: - with self._ws.workspace.download(f"{self.install_folder()}/{filename[0:-4]}.{current_part}.csv") as f: - content += self._convert_content(filename, f) - current_part += 1 + with self._ws.workspace.download( + f"{self.install_folder()}/{filename[0:-4]}.{current_part}.csv" + ) as f: + converted_content = self._convert_content(filename, f) + # check if converted_content is a list + if isinstance(converted_content, list): + content += converted_content + else: + content.append(converted_content) except NotFound: if current_part == 1: raise return content - @classmethod - def _convert_content(cls, filename: str, raw: BinaryIO) -> Json|list[Json]: + def _convert_content(cls, filename: str, raw: BinaryIO) -> Json | list[Json]: """The `_convert_content` method is a private method that is used to convert the raw bytes of a file to a dictionary. This method is called by the `_load_content` method.""" converters: dict[str, Callable[[BinaryIO], Any]] = { diff --git a/tests/integration/test_installation.py b/tests/integration/test_installation.py index e8ba2ad..891151c 100644 --- a/tests/integration/test_installation.py +++ b/tests/integration/test_installation.py @@ -2,8 +2,9 @@ import pytest from databricks.sdk.errors import PermissionDenied -from databricks.sdk.service.provisioning import Workspace from databricks.sdk.service.catalog import TableInfo +from databricks.sdk.service.provisioning import Workspace + from databricks.labs.blueprint.installation import Installation