diff --git a/docs/index.rst b/docs/index.rst index e608fe6b..70bd5786 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -130,3 +130,24 @@ See `SerializableTokenCache` for example. .. autoclass:: msal.SerializableTokenCache :members: + + +Managed Identity +================ +MSAL supports +`Managed Identity `_. + +You can create one of these two kinds of managed identity configuration objects: + +.. autoclass:: msal.SystemAssignedManagedIdentity + :members: + +.. autoclass:: msal.UserAssignedManagedIdentity + :members: + +And then feed the configuration object into a :class:`ManagedIdentityClient` object. + +.. autoclass:: msal.ManagedIdentityClient + :members: + + .. automethod:: __init__ diff --git a/msal/__init__.py b/msal/__init__.py index 4e2faaed..d15c3a6c 100644 --- a/msal/__init__.py +++ b/msal/__init__.py @@ -33,4 +33,8 @@ ) from .oauth2cli.oidc import Prompt from .token_cache import TokenCache, SerializableTokenCache +from .managed_identity import ( + SystemAssignedManagedIdentity, UserAssignedManagedIdentity, + ManagedIdentityClient, + ) diff --git a/msal/application.py b/msal/application.py index 48b6575b..c5ed1991 100644 --- a/msal/application.py +++ b/msal/application.py @@ -22,6 +22,7 @@ from .region import _detect_region from .throttled_http_client import ThrottledHttpClient from .cloudshell import _is_running_in_cloud_shell +from .imds import ManagedIdentityClient, ManagedIdentity, _scope_to_resource # The __init__.py will import this. Not the other way around. @@ -2021,6 +2022,14 @@ def acquire_token_for_client(self, scopes, claims_challenge=None, **kwargs): - an error response would contain "error" and usually "error_description". """ # TBD: force_refresh behavior + if ManagedIdentity.is_managed_identity(self.client_id): + if len(scopes) != 1: + raise ValueError("Managed Identity supports only one scope/resource") + if claims_challenge: + raise ValueError("Managed Identity does not support claims_challenge") + return ManagedIdentityClient( + self.http_client, self.client_id, self.token_cache + ).acquire_token(_scope_to_resource(scopes[0])) if self.authority.tenant.lower() in ["common", "organizations"]: warnings.warn( "Using /common or /organizations authority " diff --git a/msal/managed_identity.py b/msal/managed_identity.py new file mode 100644 index 00000000..fa2a7221 --- /dev/null +++ b/msal/managed_identity.py @@ -0,0 +1,436 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +import json +import logging +import os +import socket +import time +try: # Python 2 + from urlparse import urlparse +except: # Python 3 + from urllib.parse import urlparse +try: # Python 3 + from collections import UserDict +except: + UserDict = dict # The real UserDict is an old-style class which fails super() +from .token_cache import TokenCache +from .throttled_http_client import ThrottledHttpClient + + +logger = logging.getLogger(__name__) + +class ManagedIdentity(UserDict): + """Feed an instance of this class to :class:`msal.ManagedIdentityClient` + to acquire token for the specified managed identity. + """ + # The key names used in config dict + ID_TYPE = "ManagedIdentityIdType" # Contains keyword ManagedIdentity so its json equivalent will be more readable + ID = "Id" + + # Valid values for key ID_TYPE + CLIENT_ID = "ClientId" + RESOURCE_ID = "ResourceId" + OBJECT_ID = "ObjectId" + SYSTEM_ASSIGNED = "SystemAssigned" + + _types_mapping = { # Maps type name in configuration to type name on wire + CLIENT_ID: "client_id", + RESOURCE_ID: "mi_res_id", + OBJECT_ID: "object_id", + } + + @classmethod + def is_managed_identity(cls, unknown): + return isinstance(unknown, dict) and cls.ID_TYPE in unknown + + @classmethod + def is_system_assigned(cls, unknown): + return isinstance(unknown, dict) and unknown.get(cls.ID_TYPE) == cls.SYSTEM_ASSIGNED + + @classmethod + def is_user_assigned(cls, unknown): + return ( + isinstance(unknown, dict) + and unknown.get(cls.ID_TYPE) in cls._types_mapping + and unknown.get(cls.ID)) + + def __init__(self, identifier=None, id_type=None): + # Undocumented. Use subclasses instead. + super(ManagedIdentity, self).__init__({ + self.ID_TYPE: id_type, + self.ID: identifier, + }) + + +class SystemAssignedManagedIdentity(ManagedIdentity): + """Represent a system-assigned managed identity, which is equivalent to:: + + {"ManagedIdentityIdType": "SystemAssigned", "Id": None} + """ + def __init__(self): + super(SystemAssignedManagedIdentity, self).__init__(id_type=self.SYSTEM_ASSIGNED) + + +class UserAssignedManagedIdentity(ManagedIdentity): + """Represent a user-assigned managed identity. + + Depends on the id you provided, the outcome is equivalent to one of the below:: + + {"ManagedIdentityIdType": "ClientId", "Id": "foo"} + {"ManagedIdentityIdType": "ResourceId", "Id": "foo"} + {"ManagedIdentityIdType": "ObjectId", "Id": "foo"} + """ + def __init__(self, client_id=None, resource_id=None, object_id=None): + if client_id and not resource_id and not object_id: + super(UserAssignedManagedIdentity, self).__init__( + id_type=self.CLIENT_ID, identifier=client_id) + elif not client_id and resource_id and not object_id: + super(UserAssignedManagedIdentity, self).__init__( + id_type=self.RESOURCE_ID, identifier=resource_id) + elif not client_id and not resource_id and object_id: + super(UserAssignedManagedIdentity, self).__init__( + id_type=self.OBJECT_ID, identifier=object_id) + else: + raise ValueError( + "You shall specify one of the three parameters: " + "client_id, resource_id, object_id") + + +class ManagedIdentityClient(object): + """This API encapulates multiple managed identity backends: + VM, App Service, Azure Automation (Runbooks), Azure Function, Service Fabric, + and Azure Arc. + + It also provides token cache support. + """ + _instance, _tenant = socket.getfqdn(), "managed_identity" # Placeholders + + def __init__(self, managed_identity, http_client, token_cache=None): + """Create a managed identity client. + + :param dict managed_identity: + It accepts an instance of :class:`SystemAssignedManagedIdentity` + or :class:`UserAssignedManagedIdentity`. + They are equivalent to a dict with a certain shape, + which may be loaded from a json configuration file or an env var, + + :param http_client: + An http client object. For example, you can use ``requests.Session()``, + optionally with exponential backoff behavior demonstrated in this recipe:: + + import requests + from requests.adapters import HTTPAdapter, Retry + s = requests.Session() + retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504]) + s.mount('https://', HTTPAdapter(max_retries=retries)) + client = ManagedIdentityClient(managed_identity, s) + + :param token_cache: + Optional. It accepts a :class:`msal.TokenCache` instance to store tokens. + It will use an in-memory token cache by default. + + Recipe 1: Hard code a managed identity for your app:: + + import msal, requests + client = msal.ManagedIdentityClient( + msal.UserAssignedManagedIdentity(client_id="foo"), + requests.Session(), + ) + token = client.acquire_token_for_client("resource") + + Recipe 2: Write once, run everywhere. + If you use different managed identity on different deployment, + you may use an environment variable (such as AZURE_MANAGED_IDENTITY) + to store a json blob like + ``{"ManagedIdentityIdType": "ClientId", "Id": "foo"}`` or + ``{"ManagedIdentityIdType": "SystemAssignedManagedIdentity", "Id": null})``. + The following app can load managed identity configuration dynamically:: + + import json, os, msal, requests + config = os.getenv("AZURE_MANAGED_IDENTITY") + assert config, "An ENV VAR with value should exist" + client = msal.ManagedIdentityClient( + json.loads(config), + requests.Session(), + ) + token = client.acquire_token_for_client("resource") + """ + self._managed_identity = managed_identity + if isinstance(http_client, ThrottledHttpClient): + raise ValueError( + # It is a precaution to reject application.py's throttled http_client, + # whose cache life on HTTP GET 200 is too long for Managed Identity. + "This class does not currently accept a ThrottledHttpClient.") + self._http_client = http_client + self._token_cache = token_cache or TokenCache() + + def acquire_token_for_client(self, resource=None): + """Acquire token for the managed identity. + + The result will be automatically cached. + """ + if not resource: + raise ValueError( + "The resource parameter is currently required. " + "It is only declared as optional in method signature, " + "in case we want to support scope parameter in the future.") + access_token_from_cache = None + client_id_in_cache = self._managed_identity.get( + ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") + if True: # Does not offer an "if not force_refresh" option, because + # there would be built-in token cache in the service side anyway + matches = self._token_cache.find( + self._token_cache.CredentialType.ACCESS_TOKEN, + target=[resource], + query=dict( + client_id=client_id_in_cache, + environment=self._instance, + realm=self._tenant, + home_account_id=None, + ), + ) + now = time.time() + for entry in matches: + expires_in = int(entry["expires_on"]) - now + if expires_in < 5*60: # Then consider it expired + continue # Removal is not necessary, it will be overwritten + logger.debug("Cache hit an AT") + access_token_from_cache = { # Mimic a real response + "access_token": entry["secret"], + "token_type": entry.get("token_type", "Bearer"), + "expires_in": int(expires_in), # OAuth2 specs defines it as int + } + if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging + break # With a fallback in hand, we break here to go refresh + return access_token_from_cache # It is still good as new + try: + result = _obtain_token(self._http_client, self._managed_identity, resource) + if "access_token" in result: + expires_in = result.get("expires_in", 3600) + if "refresh_in" not in result and expires_in >= 7200: + result["refresh_in"] = int(expires_in / 2) + self._token_cache.add(dict( + client_id=client_id_in_cache, + scope=[resource], + token_endpoint="https://{}/{}".format(self._instance, self._tenant), + response=result, + params={}, + data={}, + )) + if (result and "error" not in result) or (not access_token_from_cache): + return result + except: # The exact HTTP exception is transportation-layer dependent + # Typically network error. Potential AAD outage? + if not access_token_from_cache: # It means there is no fall back option + raise # We choose to bubble up the exception + return access_token_from_cache + + +def _scope_to_resource(scope): # This is an experimental reasonable-effort approach + u = urlparse(scope) + if u.scheme: + return "{}://{}".format(u.scheme, u.netloc) + return scope # There is no much else we can do here + + +def _obtain_token(http_client, managed_identity, resource): + # A unified low-level API that talks to different Managed Identity + if ("IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ + and "IDENTITY_SERVER_THUMBPRINT" in os.environ + ): + if managed_identity: + logger.debug( + "Ignoring managed_identity parameter. " + "Managed Identity in Service Fabric is configured in the cluster, " + "not during runtime. See also " + "https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service") + return _obtain_token_on_service_fabric( + http_client, + os.environ["IDENTITY_ENDPOINT"], + os.environ["IDENTITY_HEADER"], + os.environ["IDENTITY_SERVER_THUMBPRINT"], + resource, + ) + if "IDENTITY_ENDPOINT" in os.environ and "IDENTITY_HEADER" in os.environ: + return _obtain_token_on_app_service( + http_client, + os.environ["IDENTITY_ENDPOINT"], + os.environ["IDENTITY_HEADER"], + managed_identity, + resource, + ) + if "IDENTITY_ENDPOINT" in os.environ and "IMDS_ENDPOINT" in os.environ: + if ManagedIdentity.is_user_assigned(managed_identity): + raise ValueError( # Note: Azure Identity for Python raised exception too + "Ignoring managed_identity parameter. " + "Azure Arc supports only system-assigned managed identity, " + "See also " + "https://learn.microsoft.com/en-us/azure/service-fabric/configure-existing-cluster-enable-managed-identity-token-service") + return _obtain_token_on_arc( + http_client, + os.environ["IDENTITY_ENDPOINT"], + resource, + ) + return _obtain_token_on_azure_vm(http_client, managed_identity, resource) + + +def _adjust_param(params, managed_identity): + id_name = ManagedIdentity._types_mapping.get( + managed_identity.get(ManagedIdentity.ID_TYPE)) + if id_name: + params[id_name] = managed_identity[ManagedIdentity.ID] + +def _obtain_token_on_azure_vm(http_client, managed_identity, resource): + # Based on https://docs.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + logger.debug("Obtaining token via managed identity on Azure VM") + params = { + "api-version": "2018-02-01", + "resource": resource, + } + _adjust_param(params, managed_identity) + resp = http_client.get( + "http://169.254.169.254/metadata/identity/oauth2/token", + params=params, + headers={"Metadata": "true"}, + ) + try: + payload = json.loads(resp.text) + if payload.get("access_token") and payload.get("expires_in"): + return { # Normalizing the payload into OAuth2 format + "access_token": payload["access_token"], + "expires_in": int(payload["expires_in"]), + "resource": payload.get("resource"), + "token_type": payload.get("token_type", "Bearer"), + } + return payload # Typically an error, but it is undefined in the doc above + except ValueError: + logger.debug("IMDS emits unexpected payload: %s", resp.text) + raise + +def _obtain_token_on_app_service( + http_client, endpoint, identity_header, managed_identity, resource, +): + """Obtains token for + `App Service `_, + Azure Functions, and Azure Automation. + """ + # Prerequisite: Create your app service https://docs.microsoft.com/en-us/azure/app-service/quickstart-python + # Assign it a managed identity https://docs.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp + # SSH into your container for testing https://docs.microsoft.com/en-us/azure/app-service/configure-linux-open-ssh-session + logger.debug("Obtaining token via managed identity on Azure App Service") + params = { + "api-version": "2019-08-01", + "resource": resource, + } + _adjust_param(params, managed_identity) + resp = http_client.get( + endpoint, + params=params, + headers={ + "X-IDENTITY-HEADER": identity_header, + "Metadata": "true", # Unnecessary yet harmless for App Service, + # It will be needed by Azure Automation + # https://docs.microsoft.com/en-us/azure/automation/enable-managed-identity-for-automation#get-access-token-for-system-assigned-managed-identity-using-http-get + }, + ) + try: + payload = json.loads(resp.text) + if payload.get("access_token") and payload.get("expires_on"): + return { # Normalizing the payload into OAuth2 format + "access_token": payload["access_token"], + "expires_in": int(payload["expires_on"]) - int(time.time()), + "resource": payload.get("resource"), + "token_type": payload.get("token_type", "Bearer"), + } + return { + "error": "invalid_scope", # Empirically, wrong resource ends up with a vague statusCode=500 + "error_description": "{}, {}".format( + payload.get("statusCode"), payload.get("message")), + } + except ValueError: + logger.debug("IMDS emits unexpected payload: %s", resp.text) + raise + + +def _obtain_token_on_service_fabric( + http_client, endpoint, identity_header, server_thumbprint, resource, +): + """Obtains token for + `Service Fabric `_ + """ + # Deployment https://learn.microsoft.com/en-us/azure/service-fabric/service-fabric-get-started-containers-linux + # See also https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/identity/azure-identity/tests/managed-identity-live/service-fabric/service_fabric.md + # Protocol https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#acquiring-an-access-token-using-rest-api + logger.debug("Obtaining token via managed identity on Azure Service Fabric") + resp = http_client.get( + endpoint, + params={"api-version": "2019-07-01-preview", "resource": resource}, + headers={"Secret": identity_header}, + ) + try: + payload = json.loads(resp.text) + if payload.get("access_token") and payload.get("expires_on"): + return { # Normalizing the payload into OAuth2 format + "access_token": payload["access_token"], + "expires_in": payload["expires_on"] - int(time.time()), + "resource": payload.get("resource"), + "token_type": payload["token_type"], + } + error = payload.get("error", {}) # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling + error_mapping = { # Map Service Fabric errors into OAuth2 errors https://www.rfc-editor.org/rfc/rfc6749#section-5.2 + "SecretHeaderNotFound": "unauthorized_client", + "ManagedIdentityNotFound": "invalid_client", + "ArgumentNullOrEmpty": "invalid_scope", + } + return { + "error": error_mapping.get(payload["error"]["code"], "invalid_request"), + "error_description": resp.text, + } + except ValueError: + logger.debug("IMDS emits unexpected payload: %s", resp.text) + raise + + +def _obtain_token_on_arc(http_client, endpoint, resource): + # https://learn.microsoft.com/en-us/azure/azure-arc/servers/managed-identity-authentication + logger.debug("Obtaining token via managed identity on Azure Arc") + resp = http_client.get( + endpoint, + params={"api-version": "2020-06-01", "resource": resource}, + headers={"Metadata": "true"}, + ) + www_auth = "www-authenticate" # Header in lower case + challenge = { + # Normalized to lowercase, because header names are case-insensitive + # https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 + k.lower(): v for k, v in resp.headers.items() if k.lower() == www_auth + }.get(www_auth, "").split("=") # Output will be ["Basic realm", "content"] + if not ( # https://datatracker.ietf.org/doc/html/rfc7617#section-2 + len(challenge) == 2 and challenge[0].lower() == "basic realm"): + raise ValueError("Irrecognizable WWW-Authenticate header: {}".format(resp.headers)) + with open(challenge[1]) as f: + secret = f.read() + response = http_client.get( + endpoint, + params={"api-version": "2020-06-01", "resource": resource}, + headers={"Metadata": "true", "Authorization": "Basic {}".format(secret)}, + ) + try: + payload = json.loads(response.text) + if payload.get("access_token") and payload.get("expires_in"): + # Example: https://learn.microsoft.com/en-us/azure/azure-arc/servers/media/managed-identity-authentication/bash-token-output-example.png + return { + "access_token": payload["access_token"], + "expires_in": int(payload["expires_in"]), + "token_type": payload.get("token_type", "Bearer"), + "resource": payload.get("resource"), + } + except ValueError: # Typically json.decoder.JSONDecodeError + pass + return { + "error": "invalid_request", + "error_description": response.text, + } + diff --git a/tests/msaltest.py b/tests/msaltest.py index b1556106..aca539f5 100644 --- a/tests/msaltest.py +++ b/tests/msaltest.py @@ -1,4 +1,4 @@ -import getpass, logging, pprint, sys, msal +import functools, getpass, logging, pprint, sys, requests, msal AZURE_CLI = "04b07795-8ddb-461a-bbee-02f9e1bf7b46" @@ -47,8 +47,11 @@ def _input_scopes(): raise ValueError("SSH Cert scope shall be tested by its dedicated functions") return scopes -def _select_account(app): +def _select_account(app, show_confidential_app_placeholder=False): accounts = app.get_accounts() + if show_confidential_app_placeholder and isinstance( + app, msal.ConfidentialClientApplication): + accounts.insert(0, {"username": "This Client"}) if accounts: return _select_options( accounts, @@ -60,11 +63,11 @@ def _select_account(app): def acquire_token_silent(app): """acquire_token_silent() - with an account already signed into MSAL Python.""" - account = _select_account(app) + account = _select_account(app, show_confidential_app_placeholder=True) if account: pprint.pprint(app.acquire_token_silent( _input_scopes(), - account=account, + account=account if "home_account_id" in account else None, force_refresh=_input_boolean("Bypass MSAL Python's token cache?"), )) @@ -138,29 +141,57 @@ def remove_account(app): app.remove_account(account) print('Account "{}" and/or its token(s) are signed out from MSAL Python'.format(account["username"])) +def acquire_token_for_managed_identity(app): + """acquire_token() - Only for managed identity""" + pprint.pprint(app.acquire_token(_select_options([ + "https://management.azure.com", + "https://graph.microsoft.com", + ], + header="Acquire token for this resource", + accept_nonempty_string=True))) + def exit(app): """Exit""" bug_link = ( "https://identitydivision.visualstudio.com/Engineering/_queries/query/79b3a352-a775-406f-87cd-a487c382a8ed/" - if app._enable_broker else + if getattr(app, "_enable_broker", None) else "https://github.com/AzureAD/microsoft-authentication-library-for-python/issues/new/choose" ) print("Bye. If you found a bug, please report it here: {}".format(bug_link)) sys.exit() +def _managed_identity(): + mi = _select_options([ + { + 'ManagedIdentityIdType': 'SystemAssignedManagedIdentity', + "name": "System-assigned managed identity", + }], + option_renderer=lambda a: a["name"], + header="Choose the system-assigned managed identity " + "(or type in your user-assigned managed identity's client id)", + accept_nonempty_string=True) + return msal.ManagedIdentityClient( + requests.Session(), + mi if isinstance(mi, dict) else msal.UserAssignedManagedIdentity( + identifier=mi, id_type=msal.UserAssignedManagedIdentity.CLIENT_ID), + token_cache=msal.TokenCache(), + ) + def main(): - print("Welcome to the Msal Python Console Test App, committed at 2022-5-2\n") + print("Welcome to the Console Test App for MSAL Python {}\n".format(msal.__version__)) chosen_app = _select_options([ {"client_id": AZURE_CLI, "name": "Azure CLI (Correctly configured for MSA-PT)"}, {"client_id": VISUAL_STUDIO, "name": "Visual Studio (Correctly configured for MSA-PT)"}, {"client_id": "95de633a-083e-42f5-b444-a4295d8e9314", "name": "Whiteboard Services (Non MSA-PT app. Accepts AAD & MSA accounts.)"}, + {"test_managed_identity": None, "name": "Managed Identity (Only works when running inside a supported environment, such as Azure VM, Azure App Service, Azure Automation)"}, ], option_renderer=lambda a: a["name"], header="Impersonate this app (or you can type in the client_id of your own app)", accept_nonempty_string=True) - app = msal.PublicClientApplication( - chosen_app["client_id"] if isinstance(chosen_app, dict) else chosen_app, - authority=_select_options([ + if isinstance(chosen_app, dict) and "test_managed_identity" in chosen_app: + app = _managed_identity() + else: + authority = _select_options([ "https://login.microsoftonline.com/common", "https://login.microsoftonline.com/organizations", "https://login.microsoftonline.com/microsoft.onmicrosoft.com", @@ -169,21 +200,32 @@ def main(): ], header="Input authority (Note that MSA-PT apps would NOT use the /common authority)", accept_nonempty_string=True, - ), - allow_broker=_input_boolean("Allow broker? (Azure CLI currently only supports @microsoft.com accounts when enabling broker)"), ) + app = msal.PublicClientApplication( + chosen_app["client_id"] if isinstance(chosen_app, dict) else chosen_app, + authority=authority, + allow_broker=_input_boolean("Allow broker? (Azure CLI currently only supports @microsoft.com accounts when enabling broker)"), + ) if _input_boolean("Enable MSAL Python's DEBUG log?"): logging.basicConfig(level=logging.DEBUG) + methods_to_be_tested = functools.reduce(lambda x, y: x + y, [ + methods for app_type, methods in { + msal.PublicClientApplication: [ + acquire_token_interactive, + acquire_ssh_cert_silently, + acquire_ssh_cert_interactive, + ], + msal.ClientApplication: [ + acquire_token_silent, + acquire_token_by_username_password, + remove_account, + ], + msal.ManagedIdentityClient: [acquire_token_for_managed_identity], + }.items() if isinstance(app, app_type)]) while True: - func = _select_options([ - acquire_token_silent, - acquire_token_interactive, - acquire_token_by_username_password, - acquire_ssh_cert_silently, - acquire_ssh_cert_interactive, - remove_account, - exit, - ], option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:") + func = _select_options( + methods_to_be_tested + [exit], + option_renderer=lambda f: f.__doc__, header="MSAL Python APIs:") try: func(app) except ValueError as e: diff --git a/tests/test_mi.py b/tests/test_mi.py new file mode 100644 index 00000000..2fc8f374 --- /dev/null +++ b/tests/test_mi.py @@ -0,0 +1,192 @@ +import json +import os +import sys +import time +import unittest +try: + from unittest.mock import patch, ANY, mock_open +except: + from mock import patch, ANY, mock_open +import requests + +from tests.http_client import MinimalResponse +from msal import ( + ConfidentialClientApplication, + SystemAssignedManagedIdentity, UserAssignedManagedIdentity, + ) + + +class ManagedIdentityTestCase(unittest.TestCase): + def test_helper_class_should_be_interchangable_with_dict_which_could_be_loaded_from_file_or_env_var(self): + self.assertEqual( + UserAssignedManagedIdentity(client_id="foo"), + {"ManagedIdentityIdType": "ClientId", "Id": "foo"}) + self.assertEqual( + UserAssignedManagedIdentity(resource_id="foo"), + {"ManagedIdentityIdType": "ResourceId", "Id": "foo"}) + self.assertEqual( + UserAssignedManagedIdentity(object_id="foo"), + {"ManagedIdentityIdType": "ObjectId", "Id": "foo"}) + with self.assertRaises(ValueError): + UserAssignedManagedIdentity() + with self.assertRaises(ValueError): + UserAssignedManagedIdentity(client_id="foo", resource_id="bar") + self.assertEqual( + SystemAssignedManagedIdentity(), + {"ManagedIdentityIdType": "SystemAssigned", "Id": None}) + + +class ClientTestCase(unittest.TestCase): + maxDiff = None + + def setUp(self): + system_assigned = {"ManagedIdentityIdType": "SystemAssigned", "Id": None} + self.app = ConfidentialClientApplication(client_id=system_assigned) + + def _test_token_cache(self, app): + cache = app.token_cache._cache + self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT") + at = list(cache["AccessToken"].values())[0] + self.assertEqual( + app.client_id.get("Id", "SYSTEM_ASSIGNED_MANAGED_IDENTITY"), + at["client_id"], + "Should have expected client_id") + self.assertEqual("managed_identity", at["realm"], "Should have expected realm") + + def _test_happy_path(self, app, mocked_http): + #result = app.acquire_token_for_client(resource="R") + result = app.acquire_token_for_client(["R"]) + mocked_http.assert_called() + self.assertEqual({ + "access_token": "AT", + "expires_in": 1234, + "resource": "R", + "token_type": "Bearer", + }, result, "Should obtain a token response") + self.assertEqual( + result["access_token"], + app.acquire_token_for_client(["R"]).get("access_token"), + "Should hit the same token from cache") + self._test_token_cache(app) + + +class VmTestCase(ClientTestCase): + + def test_happy_path(self): + with patch.object(self.app.http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}', + )) as mocked_method: + self._test_happy_path(self.app, mocked_method) + + def test_vm_error_should_be_returned_as_is(self): + raw_error = '{"raw": "error format is undefined"}' + with patch.object(self.app.http_client, "get", return_value=MinimalResponse( + status_code=400, + text=raw_error, + )) as mocked_method: + self.assertEqual( + json.loads(raw_error), self.app.acquire_token_for_client(["R"])) + self.assertEqual({}, self.app.token_cache._cache) + + +@patch.dict(os.environ, {"IDENTITY_ENDPOINT": "http://localhost", "IDENTITY_HEADER": "foo"}) +class AppServiceTestCase(ClientTestCase): + + def test_happy_path(self): + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % ( + int(time.time()) + 1234), + )) as mocked_method: + self._test_happy_path(self.app, mocked_method) + + def test_app_service_error_should_be_normalized(self): + raw_error = '{"statusCode": 500, "message": "error content is undefined"}' + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=500, + text=raw_error, + )) as mocked_method: + self.assertEqual({ + "error": "invalid_scope", + "error_description": "500, error content is undefined", + }, self.app.acquire_token_for_client(resource="R")) + self.assertEqual({}, self.app._token_cache._cache) + + +@patch.dict(os.environ, { + "IDENTITY_ENDPOINT": "http://localhost", + "IDENTITY_HEADER": "foo", + "IDENTITY_SERVER_THUMBPRINT": "bar", +}) +class ServiceFabricTestCase(ClientTestCase): + + def _test_happy_path(self, app): + with patch.object(app._http_client, "get", return_value=MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % ( + int(time.time()) + 1234), + )) as mocked_method: + super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method) + + def test_happy_path(self): + self._test_happy_path(self.app) + + def test_unified_api_service_should_ignore_unnecessary_client_id(self): + self._test_happy_path(ManagedIdentityClient( + {"ManagedIdentityIdType": "ClientId", "Id": "foo"}, + requests.Session(), + )) + + def test_sf_error_should_be_normalized(self): + raw_error = ''' +{"error": { + "correlationId": "foo", + "code": "SecretHeaderNotFound", + "message": "Secret is not found in the request headers." +}}''' # https://learn.microsoft.com/en-us/azure/service-fabric/how-to-managed-identity-service-fabric-app-code#error-handling + with patch.object(self.app._http_client, "get", return_value=MinimalResponse( + status_code=404, + text=raw_error, + )) as mocked_method: + self.assertEqual({ + "error": "unauthorized_client", + "error_description": raw_error, + }, self.app.acquire_token_for_client(resource="R")) + self.assertEqual({}, self.app._token_cache._cache) + + +@patch.dict(os.environ, { + "IDENTITY_ENDPOINT": "http://localhost/token", + "IMDS_ENDPOINT": "http://localhost", +}) +@patch( + "builtins.open" if sys.version_info.major >= 3 else "__builtin__.open", + mock_open(read_data="secret") +) +class ArcTestCase(ClientTestCase): + challenge = MinimalResponse(status_code=401, text="", headers={ + "WWW-Authenticate": "Basic realm=/tmp/foo", + }) + + def test_happy_path(self): + with patch.object(self.app._http_client, "get", side_effect=[ + self.challenge, + MinimalResponse( + status_code=200, + text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}', + ), + ]) as mocked_method: + super(ArcTestCase, self)._test_happy_path(self.app, mocked_method) + + def test_arc_error_should_be_normalized(self): + with patch.object(self.app._http_client, "get", side_effect=[ + self.challenge, + MinimalResponse(status_code=400, text="undefined"), + ]) as mocked_method: + self.assertEqual({ + "error": "invalid_request", + "error_description": "undefined", + }, self.app.acquire_token_for_client(resource="R")) + self.assertEqual({}, self.app._token_cache._cache) +