@@ -255,21 +255,18 @@ def test_azure_missing_tenant_id(self, mock_hvac):
255255 secret_id = "pass" ,
256256 )
257257
258- @mock .patch ("builtins.open" , create = True )
259258 @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider._get_scopes" )
260259 @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id" )
261260 @mock .patch ("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client" )
262261 @mock .patch ("googleapiclient.discovery.build" )
263- def test_gcp (self , mock_google_build , mock_hvac_client , mock_get_credentials , mock_get_scopes , mock_open ):
264- # Mock the content of the file 'path.json'
265- mock_file = mock .MagicMock ()
266- mock_file .read .return_value = '{"client_email": "service_account_email"}'
267- mock_open .return_value .__enter__ .return_value = mock_file
268-
262+ def test_gcp (self , mock_google_build , mock_hvac_client , mock_get_credentials , mock_get_scopes ):
269263 mock_client = mock .MagicMock ()
270264 mock_hvac_client .return_value = mock_client
271265 mock_get_scopes .return_value = ["scope1" , "scope2" ]
272- mock_get_credentials .return_value = ("credentials" , "project_id" )
266+
267+ mock_credentials = mock .MagicMock ()
268+ mock_credentials .client_email = "service_account_email"
269+ mock_get_credentials .return_value = (mock_credentials , "project_id" )
273270
274271 # Mock the current time to use for iat and exp
275272 current_time = int (time .time ())
@@ -321,29 +318,73 @@ def mocked_json_dumps(payload):
321318 # Assert iat and exp values are as expected
322319 assert payload ["iat" ] == iat
323320 assert payload ["exp" ] == exp
321+ assert payload ["sub" ] == "service_account_email"
324322 assert abs (payload ["exp" ] - (payload ["iat" ] + 3600 )) < 10 # Validate exp is 3600 seconds after iat
325323
326324 client .auth .gcp .login .assert_called_with (role = "role" , jwt = "mocked_jwt" )
327325 client .is_authenticated .assert_called_with ()
328326 assert vault_client .kv_engine_version == 2
329327
330- @mock .patch ("builtins.open" , create = True )
328+ @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider._get_scopes" )
329+ @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id" )
330+ @mock .patch ("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client" )
331+ @mock .patch ("googleapiclient.discovery.build" )
332+ def test_gcp_adc (self , mock_google_build , mock_hvac_client , mock_get_credentials , mock_get_scopes ):
333+ mock_client = mock .MagicMock ()
334+ mock_hvac_client .return_value = mock_client
335+ mock_get_scopes .return_value = ["scope1" , "scope2" ]
336+
337+ mock_credentials = mock .MagicMock ()
338+ mock_credentials .service_account_email = "service_account_email"
339+ mock_get_credentials .return_value = (mock_credentials , "project_id" )
340+
341+ mock_sign_jwt = (
342+ mock_google_build .return_value .projects .return_value .serviceAccounts .return_value .signJwt
343+ )
344+ mock_sign_jwt .return_value .execute .return_value = {"signedJwt" : "mocked_jwt" }
345+
346+ vault_client = _VaultClient (
347+ auth_type = "gcp" ,
348+ gcp_scopes = "scope1,scope2" ,
349+ role_id = "role" ,
350+ url = "http://localhost:8180" ,
351+ session = None ,
352+ )
353+
354+ client = vault_client .client # Trigger the Vault client creation
355+
356+ # Validate that the HVAC client and other mocks are called correctly
357+ mock_hvac_client .assert_called_with (url = "http://localhost:8180" , session = None )
358+ mock_get_scopes .assert_called_with ("scope1,scope2" )
359+ mock_get_credentials .assert_called_with (
360+ key_path = None , keyfile_dict = None , scopes = ["scope1" , "scope2" ]
361+ )
362+
363+ # Extract the arguments passed to the mocked signJwt API
364+ args , kwargs = mock_sign_jwt .call_args
365+ payload = json .loads (kwargs ["body" ]["payload" ])
366+
367+ # Assert sub is correctly set to service account email
368+ assert payload ["sub" ] == "service_account_email"
369+
370+ client .auth .gcp .login .assert_called_with (role = "role" , jwt = "mocked_jwt" )
371+ client .is_authenticated .assert_called_with ()
372+ assert vault_client .kv_engine_version == 2
373+
331374 @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider._get_scopes" )
332375 @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id" )
333376 @mock .patch ("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client" )
334377 @mock .patch ("googleapiclient.discovery.build" )
335378 def test_gcp_different_auth_mount_point (
336- self , mock_google_build , mock_hvac_client , mock_get_credentials , mock_get_scopes , mock_open
379+ self , mock_google_build , mock_hvac_client , mock_get_credentials , mock_get_scopes ,
337380 ):
338- # Mock the content of the file 'path.json'
339- mock_file = mock .MagicMock ()
340- mock_file .read .return_value = '{"client_email": "service_account_email"}'
341- mock_open .return_value .__enter__ .return_value = mock_file
342-
343381 mock_client = mock .MagicMock ()
344382 mock_hvac_client .return_value = mock_client
345383 mock_get_scopes .return_value = ["scope1" , "scope2" ]
346- mock_get_credentials .return_value = ("credentials" , "project_id" )
384+
385+ mock_credentials = mock .MagicMock ()
386+ mock_credentials .client_email = "service_account_email"
387+ mock_get_credentials .return_value = (mock_credentials , "project_id" )
347388
348389 mock_sign_jwt = (
349390 mock_google_build .return_value .projects .return_value .serviceAccounts .return_value .signJwt
@@ -394,26 +435,27 @@ def mocked_json_dumps(payload):
394435 # Assert iat and exp values are as expected
395436 assert payload ["iat" ] == iat
396437 assert payload ["exp" ] == exp
438+ assert payload ["sub" ] == "service_account_email"
397439 assert abs (payload ["exp" ] - (payload ["iat" ] + 3600 )) < 10 # Validate exp is 3600 seconds after iat
398440
399441 client .auth .gcp .login .assert_called_with (role = "role" , jwt = "mocked_jwt" , mount_point = "other" )
400442 client .is_authenticated .assert_called_with ()
401443 assert vault_client .kv_engine_version == 2
402444
403- @mock .patch (
404- "builtins.open" , new_callable = mock_open , read_data = '{"client_email": "service_account_email"}'
405- )
406445 @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider._get_scopes" )
407446 @mock .patch ("airflow.providers.google.cloud.utils.credentials_provider.get_credentials_and_project_id" )
408447 @mock .patch ("airflow.providers.hashicorp._internal_client.vault_client.hvac.Client" )
409448 @mock .patch ("googleapiclient.discovery.build" )
410449 def test_gcp_dict (
411- self , mock_google_build , mock_hvac_client , mock_get_credentials , mock_get_scopes , mock_file
450+ self , mock_google_build , mock_hvac_client , mock_get_credentials , mock_get_scopes
412451 ):
413452 mock_client = mock .MagicMock ()
414453 mock_hvac_client .return_value = mock_client
415454 mock_get_scopes .return_value = ["scope1" , "scope2" ]
416- mock_get_credentials .return_value = ("credentials" , "project_id" )
455+
456+ mock_credentials = mock .MagicMock ()
457+ mock_credentials .client_email = "service_account_email"
458+ mock_get_credentials .return_value = (mock_credentials , "project_id" )
417459
418460 mock_sign_jwt = (
419461 mock_google_build .return_value .projects .return_value .serviceAccounts .return_value .signJwt
@@ -463,6 +505,7 @@ def mocked_json_dumps(payload):
463505 # Assert iat and exp values are as expected
464506 assert payload ["iat" ] == iat
465507 assert payload ["exp" ] == exp
508+ assert payload ["sub" ] == "service_account_email"
466509 assert abs (payload ["exp" ] - (payload ["iat" ] + 3600 )) < 10 # Validate exp is 3600 seconds after iat
467510
468511 client .auth .gcp .login .assert_called_with (role = "role" , jwt = "mocked_jwt" )
0 commit comments