Skip to content

Commit 6a9997d

Browse files
committed
[providers] Consolidate GCP authentication in Hashicorp Vault
This commit unifies the GCP authentication logic in the Hashicorp Vault secrets backend. The 'gcp' authentication type now supports both service account key-based authentication and Application Default Credentials (ADC) automatically.

Changes:
- Consolidate 'gcp' authentication to handle both keys and managed identities.
- Automatically determine service account email from credentials, with a fallback to the GCE/GKE metadata server.
- Fix a bug where the 'sub' claim in the signed JWT payload was incorrectly set to the credentials object instead of the service account email.
- Update VaultHook and VaultBackend docstrings and initialization logic.
- Add unit tests for ADC-based GCP authentication and update existing tests.
1 parent 9ed8a16 commit 6a9997d

File tree

3 files changed

+84
-15
lines changed

3 files changed

+84
-15
lines changed

providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -348,24 +348,35 @@ def _auth_gcp(self, _client: hvac.Client) -> None:
348348
import json
349349
import time
350350

351-
type = getattr(credentials, "type", None)
352-
if type != "service_account":
353-
raise VaultError("Credentials are not of type service_account")
354-
355-
service_account_email = getattr(credentials, "client_email", None)
351+
# Determine service account email
352+
service_account_email = getattr(credentials, "service_account_email", None)
353+
if not service_account_email or not isinstance(service_account_email, str):
354+
service_account_email = getattr(credentials, "client_email", None)
355+
356+
if not service_account_email or not isinstance(service_account_email, str):
357+
# Fallback for Compute Engine credentials if email is not yet populated
358+
try:
359+
from google.auth import compute_engine
360+
if isinstance(credentials, compute_engine.Credentials):
361+
if not getattr(credentials, "service_account_email", None):
362+
from google.auth import transport
363+
credentials.refresh(transport.requests.Request())
364+
service_account_email = credentials.service_account_email
365+
except Exception:
366+
pass
367+
356368
if not service_account_email:
357369
raise VaultError("Could not determine service account email from credentials")
358370

359371
# Generate a payload for subsequent "signJwt()" call
360-
# Reference: https://googleapis.dev/python/google-auth/latest/reference/google.auth.jwt.html#google.auth.jwt.Credentials
361372
now = int(time.time())
362373
expires = now + 900 # 15 mins in seconds, can't be longer.
363-
payload = {"iat": now, "exp": expires, "sub": credentials, "aud": f"vault/{self.role_id}"}
374+
payload = {"iat": now, "exp": expires, "sub": service_account_email, "aud": f"vault/{self.role_id}"}
364375
body = {"payload": json.dumps(payload)}
365376
name = f"projects/{project_id}/serviceAccounts/{service_account_email}"
366377

367378
# Perform the GCP API call
368-
import googleapiclient
379+
import googleapiclient.discovery
369380
iam = googleapiclient.discovery.build("iam", "v1", credentials=credentials)
370381
request = iam.projects().serviceAccounts().signJwt(name=name, body=body)
371382
resp = request.execute()

providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class VaultHook(BaseHook):
7979
8080
:param vault_conn_id: The id of the connection to use
8181
:param auth_type: Authentication Type for the Vault. Default is ``token``. Available values are:
82-
('approle', 'github', 'gcp', 'jwt', 'kubernetes', 'ldap', 'token', 'userpass')
82+
('approle', 'aws_iam', 'azure', 'github', 'gcp', 'jwt', 'kubernetes', 'ldap', 'radius', 'token', 'userpass')
8383
:param auth_mount_point: It can be used to define mount_point for authentication chosen
8484
Default depends on the authentication method used.
8585
:param kv_engine_version: Select the version of the engine to run (``1`` or ``2``). Defaults to
@@ -152,13 +152,13 @@ def __init__(
152152
if kwargs:
153153
client_kwargs = merge_dicts(client_kwargs, kwargs)
154154

155-
if auth_type == "approle" and self.connection.login:
155+
if auth_type in ("approle", "gcp") and self.connection.login:
156156
role_id = self.connection.login
157157

158-
if auth_type == "aws_iam":
158+
if auth_type in ("aws_iam", "gcp"):
159159
if not role_id:
160160
role_id = self.connection.extra_dejson.get("role_id")
161-
if not region:
161+
if not region and auth_type == "aws_iam":
162162
region = self.connection.extra_dejson.get("region")
163163

164164
azure_resource, azure_tenant_id = (

providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,10 @@ def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mo
263263
mock_client = mock.MagicMock()
264264
mock_hvac_client.return_value = mock_client
265265
mock_get_scopes.return_value = ["scope1", "scope2"]
266-
mock_get_credentials.return_value = ('{"client_email": "service_account_email"}', "project_id")
266+
267+
mock_credentials = mock.MagicMock()
268+
mock_credentials.client_email = "service_account_email"
269+
mock_get_credentials.return_value = (mock_credentials, "project_id")
267270

268271
# Mock the current time to use for iat and exp
269272
current_time = int(time.time())
@@ -315,12 +318,59 @@ def mocked_json_dumps(payload):
315318
# Assert iat and exp values are as expected
316319
assert payload["iat"] == iat
317320
assert payload["exp"] == exp
321+
assert payload["sub"] == "service_account_email"
318322
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat
319323

320324
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt")
321325
client.is_authenticated.assert_called_with()
322326
assert vault_client.kv_engine_version == 2
323327

328+
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
329+
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
330+
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
331+
@mock.patch("googleapiclient.discovery.build")
332+
def test_gcp_adc(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes):
333+
mock_client = mock.MagicMock()
334+
mock_hvac_client.return_value = mock_client
335+
mock_get_scopes.return_value = ["scope1", "scope2"]
336+
337+
mock_credentials = mock.MagicMock()
338+
mock_credentials.service_account_email = "service_account_email"
339+
mock_get_credentials.return_value = (mock_credentials, "project_id")
340+
341+
mock_sign_jwt = (
342+
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
343+
)
344+
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"}
345+
346+
vault_client = _VaultClient(
347+
auth_type="gcp",
348+
gcp_scopes="scope1,scope2",
349+
role_id="role",
350+
url="http://localhost:8180",
351+
session=None,
352+
)
353+
354+
client = vault_client.client # Trigger the Vault client creation
355+
356+
# Validate that the HVAC client and other mocks are called correctly
357+
mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None)
358+
mock_get_scopes.assert_called_with("scope1,scope2")
359+
mock_get_credentials.assert_called_with(
360+
key_path=None, keyfile_dict=None, scopes=["scope1", "scope2"]
361+
)
362+
363+
# Extract the arguments passed to the mocked signJwt API
364+
args, kwargs = mock_sign_jwt.call_args
365+
payload = json.loads(kwargs["body"]["payload"])
366+
367+
# Assert sub is correctly set to service account email
368+
assert payload["sub"] == "service_account_email"
369+
370+
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt")
371+
client.is_authenticated.assert_called_with()
372+
assert vault_client.kv_engine_version == 2
373+
324374
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
325375
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
326376
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
@@ -331,7 +381,10 @@ def test_gcp_different_auth_mount_point(
331381
mock_client = mock.MagicMock()
332382
mock_hvac_client.return_value = mock_client
333383
mock_get_scopes.return_value = ["scope1", "scope2"]
334-
mock_get_credentials.return_value = ('{"client_email": "service_account_email"}', "project_id")
384+
385+
mock_credentials = mock.MagicMock()
386+
mock_credentials.client_email = "service_account_email"
387+
mock_get_credentials.return_value = (mock_credentials, "project_id")
335388

336389
mock_sign_jwt = (
337390
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
@@ -382,6 +435,7 @@ def mocked_json_dumps(payload):
382435
# Assert iat and exp values are as expected
383436
assert payload["iat"] == iat
384437
assert payload["exp"] == exp
438+
assert payload["sub"] == "service_account_email"
385439
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat
386440

387441
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt", mount_point="other")
@@ -398,7 +452,10 @@ def test_gcp_dict(
398452
mock_client = mock.MagicMock()
399453
mock_hvac_client.return_value = mock_client
400454
mock_get_scopes.return_value = ["scope1", "scope2"]
401-
mock_get_credentials.return_value = ("credentials", "project_id")
455+
456+
mock_credentials = mock.MagicMock()
457+
mock_credentials.client_email = "service_account_email"
458+
mock_get_credentials.return_value = (mock_credentials, "project_id")
402459

403460
mock_sign_jwt = (
404461
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
@@ -448,6 +505,7 @@ def mocked_json_dumps(payload):
448505
# Assert iat and exp values are as expected
449506
assert payload["iat"] == iat
450507
assert payload["exp"] == exp
508+
assert payload["sub"] == "service_account_email"
451509
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat
452510

453511
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt")

0 commit comments

Comments
 (0)