Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions discoverx/explorer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import concurrent.futures
import copy
import re
from typing import Optional, List
import itertools as it
from typing import Optional, List, Callable, Any
from discoverx import logging
from discoverx.common import helper
from discoverx.discovery import Discovery
Expand Down Expand Up @@ -165,7 +166,7 @@ def scan(
discover.scan(rules=rules, sample_size=sample_size, what_if=what_if)
return discover

def map(self, f) -> list[any]:
def map(self, f: Callable[[TableInfo], Any]) -> list[Any]:
"""Runs a function for each table in the data explorer

Args:
Expand Down Expand Up @@ -197,6 +198,42 @@ def map(self, f) -> list[any]:

return res

def map_chunked(self, f: Callable[..., Any], tables_per_chunk: int, **kwargs) -> list[Any]:
    """Runs a function for each table in the data explorer, submitting work in batches.

    Tables are fetched from the info fetcher and submitted to a thread pool in
    batches of ``tables_per_chunk``. Despite the batched submission, ``f`` is
    invoked once per individual table — each submission receives a single
    TableInfo object, not a list of them.

    Args:
        f (function): The function to run for each table. It receives a single
            TableInfo object (plus any extra ``kwargs``) and returns any object.
            A return value of None is dropped from the result list.
        tables_per_chunk (int): Number of tables submitted to the executor per batch.
        **kwargs: Extra keyword arguments forwarded to every invocation of ``f``.

    Returns:
        list[Any]: The non-None results of running ``f`` for each table, in
        task-completion order (not input order).
    """
    res = []
    table_list = self._info_fetcher.get_tables_info(
        self._catalogs,
        self._schemas,
        self._tables,
        self._having_columns,
        self._with_tags,
    )
    with concurrent.futures.ThreadPoolExecutor(max_workers=self._max_concurrency) as executor:
        # Submit tasks to the thread pool in batches of tables_per_chunk.
        # NOTE(review): all batches are submitted before any result is
        # consumed, so batching does not throttle concurrency beyond
        # max_workers — confirm whether per-batch result processing was
        # the original intent.
        table_iter = iter(table_list)
        futures = []
        while batch := [
            executor.submit(f, table_info, **kwargs) for table_info in it.islice(table_iter, tables_per_chunk)
        ]:
            futures.extend(batch)

        # Collect results as tasks complete, dropping None results.
        for future in concurrent.futures.as_completed(futures):
            result = future.result()
            if result is not None:
                res.append(result)

    logger.debug("Finished lakehouse map_chunked task")

    return res


class DataExplorerActions:
def __init__(
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/explorer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,42 @@ def test_map(spark, info_fetcher):
assert result[0].tags == None


def test_map_chunked_1(spark, info_fetcher):
    """map_chunked over a single matching table returns one TableInfo with the expected fields."""
    data_explorer = DataExplorer("*.default.tb_1", spark, info_fetcher)
    result = data_explorer.map_chunked(lambda table_info: table_info, 10)
    assert len(result) == 1
    assert result[0].table == "tb_1"
    assert result[0].schema == "default"
    # PEP 8 / E711: compare against None with `is`, not `==`.
    assert result[0].catalog is None
    assert result[0].tags is None


def test_map_chunked_2(spark, info_fetcher):
    """map_chunked returns one result per table regardless of the chunk size."""

    def check_result(result):
        # Every table matched by *.default.* must carry the expected metadata.
        # The original tb_1/tb_2 branches were identical, so they are merged.
        for res in result:
            assert res.table in ["tb_1", "tb_2", "tb_all_types"]
            assert res.schema == "default"
            # Only tb_all_types is registered under the hive_metastore catalog.
            assert res.catalog == ("hive_metastore" if res.table == "tb_all_types" else None)
            # PEP 8 / E711: compare against None with `is`, not `==`.
            assert res.tags is None

    data_explorer = DataExplorer("*.default.*", spark, info_fetcher)
    # Chunk size larger than the table count: everything fits in one batch.
    result = data_explorer.map_chunked(lambda table_info: table_info, 10)
    assert len(result) == 3
    check_result(result)
    # Chunk size smaller than the table count: multiple batches, same output.
    result2 = data_explorer.map_chunked(lambda table_info: table_info, 2)
    assert len(result2) == 3
    check_result(result2)


def test_map_with_tags(spark, info_fetcher):
data_explorer = DataExplorer("*.default.tb_1", spark, info_fetcher).with_tags()
result = data_explorer.map(lambda table_info: table_info)
Expand Down