@@ -263,7 +263,10 @@ def test_gcp(self, mock_google_build, mock_hvac_client, mock_get_credentials, mo
263263 mock_client = mock .MagicMock ()
264264 mock_hvac_client .return_value = mock_client
265265 mock_get_scopes .return_value = ["scope1" , "scope2" ]
266- mock_get_credentials .return_value = ('{"client_email": "service_account_email"}' , "project_id" )
266+
267+ mock_credentials = mock .MagicMock ()
268+ mock_credentials .client_email = "service_account_email"
269+ mock_get_credentials .return_value = (mock_credentials , "project_id" )
267270
268271 # Mock the current time to use for iat and exp
269272 current_time = int (time .time ())
@@ -315,12 +318,59 @@ def mocked_json_dumps(payload):
315318 # Assert iat and exp values are as expected
316319 assert payload ["iat" ] == iat
317320 assert payload ["exp" ] == exp
321+ assert payload ["sub" ] == "service_account_email"
318322 assert abs (payload ["exp" ] - (payload ["iat" ] + 3600 )) < 10 # Validate exp is 3600 seconds after iat
319323
320324 client .auth .gcp .login .assert_called_with (role = "role" , jwt = "mocked_jwt" )
321325 client .is_authenticated .assert_called_with ()
322326 assert vault_client .kv_engine_version == 2
323327
328+ @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider._get_scopes" )
329+ @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id" )
330+ @mock .patch ("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client" )
331+ @mock .patch ("googleapiclient.discovery.build" )
332+ def test_gcp_adc (self , mock_google_build , mock_hvac_client , mock_get_credentials , mock_get_scopes ):
333+ mock_client = mock .MagicMock ()
334+ mock_hvac_client .return_value = mock_client
335+ mock_get_scopes .return_value = ["scope1" , "scope2" ]
336+
337+ mock_credentials = mock .MagicMock ()
338+ mock_credentials .service_account_email = "service_account_email"
339+ mock_get_credentials .return_value = (mock_credentials , "project_id" )
340+
341+ mock_sign_jwt = (
342+ mock_google_build .return_value .projects .return_value .serviceAccounts .return_value .signJwt
343+ )
344+ mock_sign_jwt .return_value .execute .return_value = {"signedJwt" : "mocked_jwt" }
345+
346+ vault_client = _VaultClient (
347+ auth_type = "gcp" ,
348+ gcp_scopes = "scope1,scope2" ,
349+ role_id = "role" ,
350+ url = "http://localhost:8180" ,
351+ session = None ,
352+ )
353+
354+ client = vault_client .client # Trigger the Vault client creation
355+
356+ # Validate that the HVAC client and other mocks are called correctly
357+ mock_hvac_client .assert_called_with (url = "http://localhost:8180" , session = None )
358+ mock_get_scopes .assert_called_with ("scope1,scope2" )
359+ mock_get_credentials .assert_called_with (
360+ key_path = None , keyfile_dict = None , scopes = ["scope1" , "scope2" ]
361+ )
362+
363+ # Extract the arguments passed to the mocked signJwt API
364+ args , kwargs = mock_sign_jwt .call_args
365+ payload = json .loads (kwargs ["body" ]["payload" ])
366+
367+ # Assert sub is correctly set to service account email
368+ assert payload ["sub" ] == "service_account_email"
369+
370+ client .auth .gcp .login .assert_called_with (role = "role" , jwt = "mocked_jwt" )
371+ client .is_authenticated .assert_called_with ()
372+ assert vault_client .kv_engine_version == 2
373+
324374 @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider._get_scopes" )
325375 @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id" )
326376 @mock .patch ("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client" )
@@ -331,7 +381,10 @@ def test_gcp_different_auth_mount_point(
331381 mock_client = mock .MagicMock ()
332382 mock_hvac_client .return_value = mock_client
333383 mock_get_scopes .return_value = ["scope1" , "scope2" ]
334- mock_get_credentials .return_value = ('{"client_email": "service_account_email"}' , "project_id" )
384+
385+ mock_credentials = mock .MagicMock ()
386+ mock_credentials .client_email = "service_account_email"
387+ mock_get_credentials .return_value = (mock_credentials , "project_id" )
335388
336389 mock_sign_jwt = (
337390 mock_google_build .return_value .projects .return_value .serviceAccounts .return_value .signJwt
@@ -382,6 +435,7 @@ def mocked_json_dumps(payload):
382435 # Assert iat and exp values are as expected
383436 assert payload ["iat" ] == iat
384437 assert payload ["exp" ] == exp
438+ assert payload ["sub" ] == "service_account_email"
385439 assert abs (payload ["exp" ] - (payload ["iat" ] + 3600 )) < 10 # Validate exp is 3600 seconds after iat
386440
387441 client .auth .gcp .login .assert_called_with (role = "role" , jwt = "mocked_jwt" , mount_point = "other" )
@@ -398,7 +452,10 @@ def test_gcp_dict(
398452 mock_client = mock .MagicMock ()
399453 mock_hvac_client .return_value = mock_client
400454 mock_get_scopes .return_value = ["scope1" , "scope2" ]
401- mock_get_credentials .return_value = ("credentials" , "project_id" )
455+
456+ mock_credentials = mock .MagicMock ()
457+ mock_credentials .client_email = "service_account_email"
458+ mock_get_credentials .return_value = (mock_credentials , "project_id" )
402459
403460 mock_sign_jwt = (
404461 mock_google_build .return_value .projects .return_value .serviceAccounts .return_value .signJwt
@@ -448,6 +505,7 @@ def mocked_json_dumps(payload):
448505 # Assert iat and exp values are as expected
449506 assert payload ["iat" ] == iat
450507 assert payload ["exp" ] == exp
508+ assert payload ["sub" ] == "service_account_email"
451509 assert abs (payload ["exp" ] - (payload ["iat" ] + 3600 )) < 10 # Validate exp is 3600 seconds after iat
452510
453511 client .auth .gcp .login .assert_called_with (role = "role" , jwt = "mocked_jwt" )
0 commit comments