Skip to content

feat: update pyiceberg/catalog/hive.py to support hive 4.x.x #2206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 40 additions & 18 deletions pyiceberg/catalog/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import getpass
import importlib
import logging
import socket
import time
Expand All @@ -32,12 +33,14 @@
)
from urllib.parse import urlparse

from hive_metastore.ThriftHiveMetastore import Client
from hive_metastore.ttypes import (
from hive_metastore.v4.ThriftHiveMetastore import Client
from hive_metastore.v4.ttypes import (
AlreadyExistsException,
CheckLockRequest,
EnvironmentContext,
FieldSchema,
GetTableRequest,
GetTablesRequest,
InvalidOperationException,
LockComponent,
LockLevel,
Expand All @@ -50,9 +53,9 @@
SerDeInfo,
StorageDescriptor,
UnlockRequest,
Database as HiveDatabase,
Table as HiveTable,
)
from hive_metastore.ttypes import Database as HiveDatabase
from hive_metastore.ttypes import Table as HiveTable
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from thrift.protocol import TBinaryProtocol
from thrift.transport import TSocket, TTransport
Expand Down Expand Up @@ -150,6 +153,9 @@ class _HiveClient:

_transport: TTransport
_ugi: Optional[List[str]]
_hive_version: int = 4
_hms_v3: object
_hms_v4: object

def __init__(
self,
Expand All @@ -163,6 +169,14 @@ def __init__(
self._kerberos_service_name = kerberos_service_name
self._ugi = ugi.split(":") if ugi else None
self._transport = self._init_thrift_transport()
self.hms_v3 = importlib.import_module('hive_metastore.v3.ThriftHiveMetastore')
self.hms_v4 = importlib.import_module('hive_metastore.v4.ThriftHiveMetastore')
self._hive_version = self._get_hive_version()

def _get_hive_version(self) -> int:
    """Return the major version reported by the connected Hive Metastore.

    Opens the client through its own context manager, calls ``getVersion()``
    and parses the text before the first ``.`` as an integer.

    Returns:
        The major version number, e.g. ``4`` for ``"4.0.1"``.

    Raises:
        ValueError: If the reported version does not start with a numeric
            major component.
    """
    with self as open_client:
        reported_version = open_client.getVersion()

    # presumably the metastore answers "major.minor.patch[...]"; only the
    # leading component is needed to choose between the v3 and v4 bindings.
    major = str(reported_version).strip().partition(".")[0]
    if not major.isdigit():
        # Fail with a message that names the offending value instead of the
        # bare "invalid literal for int()" that int() would produce.
        raise ValueError(f"Unable to determine Hive Metastore major version from {reported_version!r}")
    return int(major)

def _init_thrift_transport(self) -> TTransport:
url_parts = urlparse(self._uri)
Expand All @@ -174,7 +188,10 @@ def _init_thrift_transport(self) -> TTransport:

def _client(self) -> Client:
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
client = Client(protocol)
if self._hive_version < 4:
client: Client = self.hms_v3.Client(protocol)
else:
client: Client = self.hms_v4.Client(protocol)
if self._ugi:
client.set_ugi(*self._ugi)
return client
Expand Down Expand Up @@ -387,11 +404,15 @@ def _create_hive_table(self, open_client: Client, hive_table: HiveTable) -> None
except AlreadyExistsException as e:
raise TableAlreadyExistsError(f"Table {hive_table.dbName}.{hive_table.tableName} already exists") from e

def _get_hive_table(self, open_client: Client, database_name: str, table_name: str) -> HiveTable:
try:
return open_client.get_table(dbname=database_name, tbl_name=table_name)
except NoSuchObjectException as e:
raise NoSuchTableError(f"Table does not exists: {table_name}") from e
def _get_hive_table(self, open_client: Client, *, dbname: str, tbl_name: str) -> HiveTable:
    """Fetch a table from the metastore using the call that matches its version.

    For a major version below 4 the positional-style ``get_table`` call is
    used; otherwise the request-object variant ``get_table_req`` is used and
    the ``table`` field of its result is returned.

    Args:
        open_client: An already opened metastore client.
        dbname: Name of the database that holds the table.
        tbl_name: Name of the table to fetch.

    Returns:
        The Hive table object returned by the metastore.

    Raises:
        NoSuchTableError: If the metastore raises ``NoSuchObjectException``.
    """
    # NOTE(review): `_hive_version` is read from `open_client` here, while this
    # file assigns it on `_HiveClient` — confirm the object yielded by the
    # client context manager actually exposes that attribute.
    try:
        if open_client._hive_version < 4:
            return open_client.get_table(dbname=dbname, tbl_name=tbl_name)
        request = GetTableRequest(dbName=dbname, tblName=tbl_name)
        return open_client.get_table_req(request).table
    except NoSuchObjectException as e:
        # Callers such as commit_table catch NoSuchTableError, so the Thrift
        # exception must be translated rather than allowed to propagate.
        raise NoSuchTableError(f"Table does not exists: {tbl_name}") from e

def _get_table_objects_by_name(self, open_client, *, dbname, tbl_names) -> list[HiveTable]:
    """Fetch several tables of one database with the version-appropriate call.

    Args:
        open_client: An already opened metastore client.
        dbname: Name of the database that holds the tables.
        tbl_names: Names of the tables to fetch.

    Returns:
        The table objects returned by the metastore: the direct result of
        ``get_table_objects_by_name`` for major versions below 4, otherwise
        the ``tables`` field of the ``get_table_objects_by_name_req`` result.
    """
    if open_client._hive_version >= 4:
        # Newer metastores take a request object and wrap the tables in a result.
        request = GetTablesRequest(dbName=dbname, tblNames=tbl_names)
        response = open_client.get_table_objects_by_name_req(request)
        return response.tables
    return open_client.get_table_objects_by_name(dbname=dbname, tbl_names=tbl_names)

def create_table(
self,
Expand Down Expand Up @@ -435,7 +456,7 @@ def create_table(

with self._client as open_client:
self._create_hive_table(open_client, tbl)
hive_table = open_client.get_table(dbname=database_name, tbl_name=table_name)
hive_table: HiveTable = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)

return self._convert_hive_into_iceberg(hive_table)

Expand Down Expand Up @@ -465,7 +486,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
tbl = self._convert_iceberg_into_hive(staged_table)
with self._client as open_client:
self._create_hive_table(open_client, tbl)
hive_table = open_client.get_table(dbname=database_name, tbl_name=table_name)
hive_table: HiveTable = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)

return self._convert_hive_into_iceberg(hive_table)

Expand Down Expand Up @@ -538,7 +559,7 @@ def commit_table(
hive_table: Optional[HiveTable]
current_table: Optional[Table]
try:
hive_table = self._get_hive_table(open_client, database_name, table_name)
hive_table = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)
current_table = self._convert_hive_into_iceberg(hive_table)
except NoSuchTableError:
hive_table = None
Expand Down Expand Up @@ -612,7 +633,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
database_name, table_name = self.identifier_to_database_and_table(identifier, NoSuchTableError)

with self._client as open_client:
hive_table = self._get_hive_table(open_client, database_name, table_name)
hive_table = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)

return self._convert_hive_into_iceberg(hive_table)

Expand Down Expand Up @@ -656,7 +677,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
to_database_name, to_table_name = self.identifier_to_database_and_table(to_identifier)
try:
with self._client as open_client:
tbl = open_client.get_table(dbname=from_database_name, tbl_name=from_table_name)
tbl: HiveTable = self._get_hive_table(open_client, dbname=from_database_name, tbl_name=from_table_name)
tbl.dbName = to_database_name
tbl.tableName = to_table_name
open_client.alter_table_with_environment_context(
Expand Down Expand Up @@ -728,8 +749,9 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
with self._client as open_client:
return [
(database_name, table.tableName)
for table in open_client.get_table_objects_by_name(
dbname=database_name, tbl_names=open_client.get_all_tables(db_name=database_name)
for table in self._get_table_objects_by_name(
open_client, dbname=database_name,
tbl_names=open_client.get_all_tables(db_name=database_name)
)
if table.parameters.get(TABLE_TYPE, "").lower() == ICEBERG
]
Expand Down Expand Up @@ -800,7 +822,7 @@ def update_namespace_properties(
if removals:
for key in removals:
if key in parameters:
parameters.pop(key)
parameters[key] = None
removed.add(key)
if updates:
for key, value in updates.items():
Expand Down
59 changes: 36 additions & 23 deletions tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import pytest
import thrift.transport.TSocket
from hive_metastore.ttypes import (
from hive_metastore.v4.ttypes import (
AlreadyExistsException,
EnvironmentContext,
FieldSchema,
Expand All @@ -39,9 +39,9 @@
SerDeInfo,
SkewedInfo,
StorageDescriptor,
Database as HiveDatabase,
Table as HiveTable,
)
from hive_metastore.ttypes import Database as HiveDatabase
from hive_metastore.ttypes import Table as HiveTable

from pyiceberg.catalog import PropertiesUpdateSummary
from pyiceberg.catalog.hive import (
Expand Down Expand Up @@ -254,6 +254,8 @@ def test_no_uri_supplied() -> None:


def test_check_number_of_namespaces(table_schema_simple: Schema) -> None:
_HiveClient._get_hive_version = MagicMock()
_HiveClient._get_hive_version.return_value = 3
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

with pytest.raises(ValueError):
Expand All @@ -280,7 +282,8 @@ def test_create_table(

catalog._client = MagicMock()
catalog._client.__enter__().create_table.return_value = None
catalog._client.__enter__().get_table.return_value = hive_table
catalog._get_hive_table = MagicMock()
catalog._get_hive_table.return_value = hive_table
catalog._client.__enter__().get_database.return_value = hive_database
catalog.create_table(("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"})

Expand Down Expand Up @@ -459,7 +462,8 @@ def test_create_table_with_given_location_removes_trailing_slash(

catalog._client = MagicMock()
catalog._client.__enter__().create_table.return_value = None
catalog._client.__enter__().get_table.return_value = hive_table
catalog._get_hive_table = MagicMock()
catalog._get_hive_table.return_value = hive_table
catalog._client.__enter__().get_database.return_value = hive_database
catalog.create_table(
("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"}, location=f"{location}/"
Expand Down Expand Up @@ -632,8 +636,9 @@ def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabas
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

catalog._client = MagicMock()
catalog._get_hive_table = MagicMock()
catalog._client.__enter__().create_table.return_value = None
catalog._client.__enter__().get_table.return_value = hive_table
catalog._get_hive_table.return_value = hive_table
catalog._client.__enter__().get_database.return_value = hive_database
catalog.create_table(
("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
Expand Down Expand Up @@ -684,10 +689,11 @@ def test_load_table(hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

catalog._client = MagicMock()
catalog._client.__enter__().get_table.return_value = hive_table
catalog._get_hive_table = MagicMock()
catalog._get_hive_table.return_value = hive_table
table = catalog.load_table(("default", "new_tabl2e"))

catalog._client.__enter__().get_table.assert_called_with(dbname="default", tbl_name="new_tabl2e")
catalog._get_hive_table.assert_called_with(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e")

expected = TableMetadataV2(
location="s3://bucket/test/location",
Expand Down Expand Up @@ -784,11 +790,12 @@ def test_load_table_from_self_identifier(hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

catalog._client = MagicMock()
catalog._client.__enter__().get_table.return_value = hive_table
catalog._get_hive_table = MagicMock()
catalog._get_hive_table.return_value = hive_table
intermediate = catalog.load_table(("default", "new_tabl2e"))
table = catalog.load_table(intermediate.name())

catalog._client.__enter__().get_table.assert_called_with(dbname="default", tbl_name="new_tabl2e")
catalog._get_hive_table.assert_called_with(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e")

expected = TableMetadataV2(
location="s3://bucket/test/location",
Expand Down Expand Up @@ -889,7 +896,8 @@ def test_rename_table(hive_table: HiveTable) -> None:
renamed_table.tableName = "new_tabl3e"

catalog._client = MagicMock()
catalog._client.__enter__().get_table.side_effect = [hive_table, renamed_table]
catalog._get_hive_table = MagicMock()
catalog._get_hive_table.side_effect = [hive_table, renamed_table]
catalog._client.__enter__().alter_table_with_environment_context.return_value = None

from_identifier = ("default", "new_tabl2e")
Expand All @@ -898,8 +906,8 @@ def test_rename_table(hive_table: HiveTable) -> None:

assert table.name() == to_identifier

calls = [call(dbname="default", tbl_name="new_tabl2e"), call(dbname="default", tbl_name="new_tabl3e")]
catalog._client.__enter__().get_table.assert_has_calls(calls)
calls = [call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e"), call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl3e")]
catalog._get_hive_table.assert_has_calls(calls)
catalog._client.__enter__().alter_table_with_environment_context.assert_called_with(
dbname="default",
tbl_name="new_tabl2e",
Expand All @@ -912,25 +920,26 @@ def test_rename_table_from_self_identifier(hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

catalog._client = MagicMock()
catalog._client.__enter__().get_table.return_value = hive_table
catalog._get_hive_table = MagicMock()
catalog._get_hive_table.return_value = hive_table

from_identifier = ("default", "new_tabl2e")
from_table = catalog.load_table(from_identifier)
catalog._client.__enter__().get_table.assert_called_with(dbname="default", tbl_name="new_tabl2e")
catalog._get_hive_table.assert_called_with(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e")

renamed_table = copy.deepcopy(hive_table)
renamed_table.dbName = "default"
renamed_table.tableName = "new_tabl3e"

catalog._client.__enter__().get_table.side_effect = [hive_table, renamed_table]
catalog._get_hive_table.side_effect = [hive_table, renamed_table]
catalog._client.__enter__().alter_table_with_environment_context.return_value = None
to_identifier = ("default", "new_tabl3e")
table = catalog.rename_table(from_table.name(), to_identifier)

assert table.name() == to_identifier

calls = [call(dbname="default", tbl_name="new_tabl2e"), call(dbname="default", tbl_name="new_tabl3e")]
catalog._client.__enter__().get_table.assert_has_calls(calls)
calls = [call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e"), call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl3e")]
catalog._get_hive_table.assert_has_calls(calls)
catalog._client.__enter__().alter_table_with_environment_context.assert_called_with(
dbname="default",
tbl_name="new_tabl2e",
Expand All @@ -943,6 +952,7 @@ def test_rename_table_from_does_not_exists() -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

catalog._client = MagicMock()
catalog._client.__enter__()._hive_version = 3
catalog._client.__enter__().alter_table_with_environment_context.side_effect = NoSuchObjectException(
message="hive.default.does_not_exists table not found"
)
Expand All @@ -957,6 +967,7 @@ def test_rename_table_to_namespace_does_not_exists() -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

catalog._client = MagicMock()
catalog._client.__enter__()._hive_version = 3
catalog._client.__enter__().alter_table_with_environment_context.side_effect = InvalidOperationException(
message="Unable to change partition or table. Database default does not exist Check metastore logs for detailed stack.does_not_exists"
)
Expand Down Expand Up @@ -1013,13 +1024,14 @@ def test_list_tables(hive_table: HiveTable) -> None:

catalog._client = MagicMock()
catalog._client.__enter__().get_all_tables.return_value = ["table1", "table2", "table3", "table4"]
catalog._client.__enter__().get_table_objects_by_name.return_value = [tbl1, tbl2, tbl3, tbl4]
catalog._get_table_objects_by_name = MagicMock()
catalog._get_table_objects_by_name.return_value = [tbl1, tbl2, tbl3, tbl4]

got_tables = catalog.list_tables("database")
assert got_tables == [("database", "table1"), ("database", "table2")]
catalog._client.__enter__().get_all_tables.assert_called_with(db_name="database")
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(
dbname="database", tbl_names=["table1", "table2", "table3", "table4"]
catalog._get_table_objects_by_name.assert_called_with(
catalog._client.__enter__(), dbname="database", tbl_names=["table1", "table2", "table3", "table4"]
)


Expand Down Expand Up @@ -1049,7 +1061,8 @@ def test_drop_table_from_self_identifier(hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

catalog._client = MagicMock()
catalog._client.__enter__().get_table.return_value = hive_table
catalog._get_hive_table = MagicMock()
catalog._get_hive_table.return_value = hive_table
table = catalog.load_table(("default", "new_tabl2e"))

catalog._client.__enter__().get_all_databases.return_value = ["namespace1", "namespace2"]
Expand Down Expand Up @@ -1156,7 +1169,7 @@ def test_update_namespace_properties(hive_database: HiveDatabase) -> None:
name="default",
description=None,
locationUri=hive_database.locationUri,
parameters={"label": "core"},
parameters={"test": None, "label": "core"},
privileges=None,
ownerName=None,
ownerType=1,
Expand Down
File renamed without changes.
File renamed without changes.
Loading
Loading