Skip to content

Commit aaf0f92

Browse files
committed
[providers] Consolidate GCP authentication in Hashicorp Vault
This commit unifies the GCP authentication logic in the Hashicorp Vault secrets backend. The 'gcp' authentication type now automatically supports both service account key-based authentication and Application Default Credentials (ADC).

Changes:
- Consolidate 'gcp' authentication to handle both keys and managed identities.
- Automatically determine the service account email from the credentials, with a fallback to the GCE/GKE metadata server.
- Fix a bug where the 'sub' claim in the signed JWT payload was incorrectly set to the credentials object instead of the service account email.
- Update VaultHook and VaultBackend docstrings and initialization logic.
- Add unit tests for ADC-based GCP authentication and update existing tests.
1 parent 2525023 commit aaf0f92

File tree

3 files changed

+91
-38
lines changed

3 files changed

+91
-38
lines changed

providers/hashicorp/src/airflow/providers/hashicorp/_internal_client/vault_client.py

Lines changed: 22 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -171,10 +171,6 @@ def __init__(
171171
raise VaultError("The 'gcp' authentication type requires 'gcp_scopes'")
172172
if not role_id:
173173
raise VaultError("The 'gcp' authentication type requires 'role_id'")
174-
if not gcp_key_path and not gcp_keyfile_dict:
175-
raise VaultError(
176-
"The 'gcp' authentication type requires 'gcp_key_path' or 'gcp_keyfile_dict'"
177-
)
178174

179175
self.kv_engine_version = kv_engine_version or 2
180176
self.url = url
@@ -352,25 +348,35 @@ def _auth_gcp(self, _client: hvac.Client) -> None:
352348
import json
353349
import time
354350

355-
import googleapiclient
356-
357-
if self.gcp_keyfile_dict:
358-
creds = self.gcp_keyfile_dict
359-
elif self.gcp_key_path:
360-
with open(self.gcp_key_path) as f:
361-
creds = json.load(f)
362-
363-
service_account = creds["client_email"]
351+
# Determine service account email
352+
service_account_email = getattr(credentials, "service_account_email", None)
353+
if not service_account_email or not isinstance(service_account_email, str):
354+
service_account_email = getattr(credentials, "client_email", None)
355+
356+
if not service_account_email or not isinstance(service_account_email, str):
357+
# Fallback for Compute Engine credentials if email is not yet populated
358+
try:
359+
from google.auth import compute_engine
360+
if isinstance(credentials, compute_engine.Credentials):
361+
if not getattr(credentials, "service_account_email", None):
362+
from google.auth import transport
363+
credentials.refresh(transport.requests.Request())
364+
service_account_email = credentials.service_account_email
365+
except Exception:
366+
pass
367+
368+
if not service_account_email:
369+
raise VaultError("Could not determine service account email from credentials")
364370

365371
# Generate a payload for subsequent "signJwt()" call
366-
# Reference: https://googleapis.dev/python/google-auth/latest/reference/google.auth.jwt.html#google.auth.jwt.Credentials
367372
now = int(time.time())
368373
expires = now + 900 # 15 mins in seconds, can't be longer.
369-
payload = {"iat": now, "exp": expires, "sub": credentials, "aud": f"vault/{self.role_id}"}
374+
payload = {"iat": now, "exp": expires, "sub": service_account_email, "aud": f"vault/{self.role_id}"}
370375
body = {"payload": json.dumps(payload)}
371-
name = f"projects/{project_id}/serviceAccounts/{service_account}"
376+
name = f"projects/{project_id}/serviceAccounts/{service_account_email}"
372377

373378
# Perform the GCP API call
379+
import googleapiclient.discovery
374380
iam = googleapiclient.discovery.build("iam", "v1", credentials=credentials)
375381
request = iam.projects().serviceAccounts().signJwt(name=name, body=body)
376382
resp = request.execute()

providers/hashicorp/src/airflow/providers/hashicorp/hooks/vault.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,7 @@ class VaultHook(BaseHook):
7979
8080
:param vault_conn_id: The id of the connection to use
8181
:param auth_type: Authentication Type for the Vault. Default is ``token``. Available values are:
82-
('approle', 'github', 'gcp', 'jwt', 'kubernetes', 'ldap', 'token', 'userpass')
82+
('approle', 'aws_iam', 'azure', 'github', 'gcp', 'jwt', 'kubernetes', 'ldap', 'radius', 'token', 'userpass')
8383
:param auth_mount_point: It can be used to define mount_point for authentication chosen
8484
Default depends on the authentication method used.
8585
:param kv_engine_version: Select the version of the engine to run (``1`` or ``2``). Defaults to
@@ -161,6 +161,10 @@ def __init__(
161161
if not region:
162162
region = self.connection.extra_dejson.get("region")
163163

164+
if auth_type == "gcp":
165+
if not role_id:
166+
role_id = self.connection.extra_dejson.get("role_id") or self.connection.login
167+
164168
azure_resource, azure_tenant_id = (
165169
self._get_azure_parameters_from_connection(azure_resource, azure_tenant_id)
166170
if auth_type == "azure"

providers/hashicorp/tests/unit/hashicorp/_internal_client/test_vault_client.py

Lines changed: 64 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -255,21 +255,18 @@ def test_azure_missing_tenant_id(self, mock_hvac):
255255
secret_id="pass",
256256
)
257257

258-
@mock.patch("builtins.open", create=True)
259258
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
260259
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
261260
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
262261
@mock.patch("googleapiclient.discovery.build")
263-
def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open):
264-
# Mock the content of the file 'path.json'
265-
mock_file = mock.MagicMock()
266-
mock_file.read.return_value = '{"client_email": "service_account_email"}'
267-
mock_open.return_value.__enter__.return_value = mock_file
268-
262+
def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes):
269263
mock_client = mock.MagicMock()
270264
mock_hvac_client.return_value = mock_client
271265
mock_get_scopes.return_value = ["scope1", "scope2"]
272-
mock_get_credentials.return_value = ("credentials", "project_id")
266+
267+
mock_credentials = mock.MagicMock()
268+
mock_credentials.client_email = "service_account_email"
269+
mock_get_credentials.return_value = (mock_credentials, "project_id")
273270

274271
# Mock the current time to use for iat and exp
275272
current_time = int(time.time())
@@ -321,29 +318,73 @@ def mocked_json_dumps(payload):
321318
# Assert iat and exp values are as expected
322319
assert payload["iat"] == iat
323320
assert payload["exp"] == exp
321+
assert payload["sub"] == "service_account_email"
324322
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat
325323

326324
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt")
327325
client.is_authenticated.assert_called_with()
328326
assert vault_client.kv_engine_version == 2
329327

330-
@mock.patch("builtins.open", create=True)
328+
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
329+
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
330+
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
331+
@mock.patch("googleapiclient.discovery.build")
332+
def test_gcp_adc(self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes):
333+
mock_client = mock.MagicMock()
334+
mock_hvac_client.return_value = mock_client
335+
mock_get_scopes.return_value = ["scope1", "scope2"]
336+
337+
mock_credentials = mock.MagicMock()
338+
mock_credentials.service_account_email = "service_account_email"
339+
mock_get_credentials.return_value = (mock_credentials, "project_id")
340+
341+
mock_sign_jwt = (
342+
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
343+
)
344+
mock_sign_jwt.return_value.execute.return_value = {"signedJwt": "mocked_jwt"}
345+
346+
vault_client = _VaultClient(
347+
auth_type="gcp",
348+
gcp_scopes="scope1,scope2",
349+
role_id="role",
350+
url="http://localhost:8180",
351+
session=None,
352+
)
353+
354+
client = vault_client.client # Trigger the Vault client creation
355+
356+
# Validate that the HVAC client and other mocks are called correctly
357+
mock_hvac_client.assert_called_with(url="http://localhost:8180", session=None)
358+
mock_get_scopes.assert_called_with("scope1,scope2")
359+
mock_get_credentials.assert_called_with(
360+
key_path=None, keyfile_dict=None, scopes=["scope1", "scope2"]
361+
)
362+
363+
# Extract the arguments passed to the mocked signJwt API
364+
args, kwargs = mock_sign_jwt.call_args
365+
payload = json.loads(kwargs["body"]["payload"])
366+
367+
# Assert sub is correctly set to service account email
368+
assert payload["sub"] == "service_account_email"
369+
370+
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt")
371+
client.is_authenticated.assert_called_with()
372+
assert vault_client.kv_engine_version == 2
373+
331374
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
332375
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
333376
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
334377
@mock.patch("googleapiclient.discovery.build")
335378
def test_gcp_different_auth_mount_point(
336-
self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_open
379+
self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes,
337380
):
338-
# Mock the content of the file 'path.json'
339-
mock_file = mock.MagicMock()
340-
mock_file.read.return_value = '{"client_email": "service_account_email"}'
341-
mock_open.return_value.__enter__.return_value = mock_file
342-
343381
mock_client = mock.MagicMock()
344382
mock_hvac_client.return_value = mock_client
345383
mock_get_scopes.return_value = ["scope1", "scope2"]
346-
mock_get_credentials.return_value = ("credentials", "project_id")
384+
385+
mock_credentials = mock.MagicMock()
386+
mock_credentials.client_email = "service_account_email"
387+
mock_get_credentials.return_value = (mock_credentials, "project_id")
347388

348389
mock_sign_jwt = (
349390
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
@@ -394,26 +435,27 @@ def mocked_json_dumps(payload):
394435
# Assert iat and exp values are as expected
395436
assert payload["iat"] == iat
396437
assert payload["exp"] == exp
438+
assert payload["sub"] == "service_account_email"
397439
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat
398440

399441
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt", mount_point="other")
400442
client.is_authenticated.assert_called_with()
401443
assert vault_client.kv_engine_version == 2
402444

403-
@mock.patch(
404-
"builtins.open", new_callable=mock_open, read_data='{"client_email": "service_account_email"}'
405-
)
406445
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider._get_scopes")
407446
@mock.patch("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id")
408447
@mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client")
409448
@mock.patch("googleapiclient.discovery.build")
410449
def test_gcp_dict(
411-
self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes, mock_file
450+
self, mock_google_build, mock_hvac_client, mock_get_credentials, mock_get_scopes
412451
):
413452
mock_client = mock.MagicMock()
414453
mock_hvac_client.return_value = mock_client
415454
mock_get_scopes.return_value = ["scope1", "scope2"]
416-
mock_get_credentials.return_value = ("credentials", "project_id")
455+
456+
mock_credentials = mock.MagicMock()
457+
mock_credentials.client_email = "service_account_email"
458+
mock_get_credentials.return_value = (mock_credentials, "project_id")
417459

418460
mock_sign_jwt = (
419461
mock_google_build.return_value.projects.return_value.serviceAccounts.return_value.signJwt
@@ -463,6 +505,7 @@ def mocked_json_dumps(payload):
463505
# Assert iat and exp values are as expected
464506
assert payload["iat"] == iat
465507
assert payload["exp"] == exp
508+
assert payload["sub"] == "service_account_email"
466509
assert abs(payload["exp"] - (payload["iat"] + 3600)) < 10 # Validate exp is 3600 seconds after iat
467510

468511
client.auth.gcp.login.assert_called_with(role="role", jwt="mocked_jwt")

0 commit comments

Comments
 (0)