-
Notifications
You must be signed in to change notification settings - Fork 7
APP-8642 : Add OAuth to pyatlan #767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6dab7aa
ffee571
088ba3b
64027a4
eaf6dd7
c7903dc
c383729
c397fb2
ec8c0ec
e7bfefd
88c909a
0ff1ca6
9c08e1f
519fafb
899363b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,2 +1,8 @@ | ||
| ATLAN_BASE_URL=your_tenant_base_url | ||
|
|
||
| #API KEY based authentication | ||
| ATLAN_API_KEY=your_api_key | ||
|
|
||
| #OAuth based authentication | ||
| ATLAN_OAUTH_CLIENT_ID=your_oauth_client_id | ||
| ATLAN_OAUTH_CLIENT_SECRET=your_oauth_client_secret | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # Copyright 2025 Atlan Pte. Ltd. | ||
| import asyncio | ||
| import time | ||
| from typing import Optional | ||
|
|
||
| import httpx | ||
| from authlib.oauth2.rfc6749 import OAuth2Token | ||
|
|
||
| from pyatlan.client.constants import GET_OAUTH_CLIENT | ||
| from pyatlan.utils import API | ||
|
|
||
|
|
||
| class AsyncOAuthTokenManager: | ||
| def __init__( | ||
| self, | ||
| base_url: str, | ||
| client_id: str, | ||
| client_secret: str, | ||
| http_client: Optional[httpx.AsyncClient] = None, | ||
| connect_timeout: float = 30.0, | ||
| read_timeout: float = 900.0, | ||
| ): | ||
| self.base_url = base_url | ||
| self.client_id = client_id | ||
| self.client_secret = client_secret | ||
| self.token_url = self._create_path(GET_OAUTH_CLIENT) | ||
| self._lock = asyncio.Lock() | ||
| self._http_client = http_client or httpx.AsyncClient( | ||
| timeout=httpx.Timeout( | ||
| connect=connect_timeout, read=read_timeout, write=30.0, pool=30.0 | ||
| ) | ||
| ) | ||
| self._token: Optional[OAuth2Token] = None | ||
| self._owns_client = http_client is None | ||
|
|
||
| async def get_token(self) -> str: | ||
| async with self._lock: | ||
| if self._token and not self._token.is_expired(): | ||
| return str(self._token["access_token"]) | ||
|
|
||
| response = await self._http_client.post( | ||
| self.token_url, | ||
| json={ | ||
| "clientId": self.client_id, | ||
| "clientSecret": self.client_secret, | ||
| }, | ||
| headers={"Content-Type": "application/json"}, | ||
| ) | ||
| response.raise_for_status() | ||
|
|
||
| data = response.json() | ||
| access_token = data.get("accessToken") or data.get("access_token") | ||
|
|
||
| if not access_token: | ||
| raise ValueError( | ||
| f"OAuth token response missing 'accessToken' field. " | ||
| f"Response keys: {list(data.keys())}" | ||
| ) | ||
|
|
||
| expires_in = data.get("expiresIn") or data.get("expires_in", 600) | ||
vaibhavatlan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| self._token = OAuth2Token( | ||
| { | ||
| "access_token": access_token, | ||
| "token_type": data.get("tokenType") | ||
| or data.get("token_type", "Bearer"), | ||
| "expires_in": expires_in, | ||
| "expires_at": int(time.time()) + expires_in, | ||
| } | ||
| ) | ||
|
|
||
| return access_token | ||
|
|
||
| async def invalidate_token(self): | ||
| async with self._lock: | ||
| self._token = None | ||
|
|
||
| def _create_path(self, api: API): | ||
| from urllib.parse import urljoin | ||
|
|
||
| if self.base_url == "INTERNAL": | ||
| base_with_prefix = urljoin(api.endpoint.service, api.endpoint.prefix) | ||
| return urljoin(base_with_prefix, api.path) | ||
| else: | ||
| base_with_prefix = urljoin(self.base_url, api.endpoint.prefix) | ||
| return urljoin(base_with_prefix, api.path) | ||
|
|
||
| async def aclose(self): | ||
| if self._owns_client: | ||
| await self._http_client.aclose() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,6 +48,7 @@ | |
| from pyatlan.client.file import FileClient | ||
| from pyatlan.client.group import GroupClient | ||
| from pyatlan.client.impersonate import ImpersonationClient | ||
| from pyatlan.client.oauth import OAuthTokenManager | ||
| from pyatlan.client.open_lineage import OpenLineageClient | ||
| from pyatlan.client.query import QueryClient | ||
| from pyatlan.client.role import RoleClient | ||
|
|
@@ -127,7 +128,9 @@ def log_response(response, *args, **kwargs): | |
|
|
||
| class AtlanClient(BaseSettings): | ||
| base_url: Union[Literal["INTERNAL"], HttpUrl] | ||
| api_key: str | ||
| api_key: Optional[str] = None | ||
| oauth_client_id: Optional[str] = None | ||
| oauth_client_secret: Optional[str] = None | ||
| connect_timeout: float = 30.0 # 30 secs | ||
| read_timeout: float = 900.0 # 15 mins | ||
| retry: Retry = DEFAULT_RETRY | ||
|
|
@@ -137,6 +140,7 @@ class AtlanClient(BaseSettings): | |
| _session: httpx.Client = PrivateAttr() | ||
| _request_params: dict = PrivateAttr() | ||
| _user_id: Optional[str] = PrivateAttr(default=None) | ||
| _oauth_token_manager: Optional[Any] = PrivateAttr(default=None) | ||
| _workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None) | ||
| _credential_client: Optional[CredentialClient] = PrivateAttr(default=None) | ||
| _admin_client: Optional[AdminClient] = PrivateAttr(default=None) | ||
|
|
@@ -172,11 +176,33 @@ class Config: | |
|
|
||
| def __init__(self, **data): | ||
| super().__init__(**data) | ||
| self._request_params = ( | ||
| {"headers": {"authorization": f"Bearer {self.api_key}"}} | ||
| if self.api_key and self.api_key.strip() | ||
| else {"headers": {}} | ||
| ) | ||
|
|
||
| if self.oauth_client_id and self.oauth_client_secret and self.api_key is None: | ||
| LOGGER.debug("API KEY not provided. Using OAuth flow for authentication") | ||
|
|
||
| final_base_url = self.base_url or os.environ.get( | ||
| "ATLAN_BASE_URL", "INTERNAL" | ||
| ) | ||
| final_oauth_client_id = self.oauth_client_id or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_ID" | ||
| ) | ||
| final_oauth_client_secret = self.oauth_client_secret or os.environ.get( | ||
| "ATLAN_OAUTH_CLIENT_SECRET" | ||
| ) | ||
| self._oauth_token_manager = OAuthTokenManager( | ||
| base_url=final_base_url, | ||
| client_id=final_oauth_client_id, | ||
| client_secret=final_oauth_client_secret, | ||
| connect_timeout=self.connect_timeout, | ||
| read_timeout=self.read_timeout, | ||
| ) | ||
| self._request_params = {"headers": {}} | ||
| else: | ||
| self._request_params = ( | ||
| {"headers": {"authorization": f"Bearer {self.api_key}"}} | ||
| if self.api_key and self.api_key.strip() | ||
| else {"headers": {}} | ||
| ) | ||
vaibhavatlan marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Sync Client OAuth Leaks HTTP ResourcesThe sync |
||
|
|
||
| # Build proxy/SSL configuration with environment variable fallback | ||
| transport_kwargs = self._build_transport_proxy_config(data) | ||
|
|
@@ -691,7 +717,7 @@ def _call_api_internal( | |
| # Retry with impersonation (if _user_id is present) | ||
| # on authentication failure (token may have expired) | ||
| if ( | ||
| self._user_id | ||
| (self._user_id or self._oauth_token_manager) | ||
| and not self._401_has_retried.get() | ||
| and response.status_code | ||
| == ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code | ||
|
|
@@ -813,6 +839,9 @@ def _create_params( | |
| self, api: API, query_params, request_obj, exclude_unset: bool = True | ||
| ): | ||
| params = copy.deepcopy(self._request_params) | ||
| if self._oauth_token_manager: | ||
| token = self._oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| params["headers"]["Accept"] = api.consumes | ||
| params["headers"]["content-type"] = api.produces | ||
| if query_params is not None: | ||
|
|
@@ -846,6 +875,21 @@ def _handle_401_token_refresh( | |
|
|
||
| returns: HTTP response received after retrying the request with the refreshed token | ||
| """ | ||
| if self._oauth_token_manager: | ||
| self._oauth_token_manager.invalidate_token() | ||
| token = self._oauth_token_manager.get_token() | ||
| params["headers"]["authorization"] = f"Bearer {token}" | ||
| self._401_has_retried.set(True) | ||
| LOGGER.debug("Successfully refreshed OAuth token after 401.") | ||
| return self._call_api_internal( | ||
| api, | ||
| path, | ||
| params, | ||
| binary_data=binary_data, | ||
| download_file_path=download_file_path, | ||
| text_response=text_response, | ||
| ) | ||
|
|
||
| try: | ||
| new_token = self.impersonate.user(user_id=self._user_id) | ||
| except Exception as e: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.