Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
ATLAN_BASE_URL=your_tenant_base_url

# API-key-based authentication
ATLAN_API_KEY=your_api_key

# OAuth-based authentication
ATLAN_OAUTH_CLIENT_ID=your_oauth_client_id
ATLAN_OAUTH_CLIENT_SECRET=your_oauth_client_secret
51 changes: 49 additions & 2 deletions pyatlan/client/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from contextlib import _AsyncGeneratorContextManager
from http import HTTPStatus
from types import SimpleNamespace
from typing import Optional
from typing import Any, Optional

import httpx
from httpx_retries.retry import Retry
Expand Down Expand Up @@ -90,6 +90,7 @@ class AsyncAtlanClient(AtlanClient):
"""

_async_session: Optional[httpx.AsyncClient] = PrivateAttr(default=None)
_async_oauth_token_manager: Optional[Any] = PrivateAttr(default=None)
_async_admin_client: Optional[AsyncAdminClient] = PrivateAttr(default=None)
_async_asset_client: Optional[AsyncAssetClient] = PrivateAttr(default=None)
_async_audit_client: Optional[AsyncAuditClient] = PrivateAttr(default=None)
Expand Down Expand Up @@ -134,6 +135,34 @@ def __init__(self, **kwargs):
# Initialize sync client (handles all validation, env vars, etc.)
super().__init__(**kwargs)

if self.oauth_client_id and self.oauth_client_secret and self.api_key is None:
LOGGER.debug(
"API Key not provided. Using Async OAuth flow for authentication"
)
from pyatlan.client.aio.oauth import AsyncOAuthTokenManager

if self._oauth_token_manager:
LOGGER.debug("Sync oauth flow open. Closing it for Async oauth flow")
self._oauth_token_manager.close()
self._oauth_token_manager = None

final_base_url = self.base_url or os.environ.get(
"ATLAN_BASE_URL", "INTERNAL"
)
final_oauth_client_id = self.oauth_client_id or os.environ.get(
"ATLAN_OAUTH_CLIENT_ID"
)
final_oauth_client_secret = self.oauth_client_secret or os.environ.get(
"ATLAN_OAUTH_CLIENT_SECRET"
)
self._async_oauth_token_manager = AsyncOAuthTokenManager(
base_url=final_base_url,
client_id=final_oauth_client_id,
client_secret=final_oauth_client_secret,
connect_timeout=self.connect_timeout,
read_timeout=self.read_timeout,
)

# Build proxy/SSL configuration (reuse from sync client)
transport_kwargs = self._build_transport_proxy_config(kwargs)

Expand Down Expand Up @@ -438,6 +467,9 @@ async def _create_params(
Async version of _create_params that uses AsyncAtlanRequest for AtlanObject instances.
"""
params = copy.deepcopy(self._request_params)
if self._async_oauth_token_manager:
token = await self._async_oauth_token_manager.get_token()
params["headers"]["authorization"] = f"Bearer {token}"
params["headers"]["Accept"] = api.consumes
params["headers"]["content-type"] = api.produces
if query_params is not None:
Expand Down Expand Up @@ -687,7 +719,7 @@ async def _handle_error_response(

# Retry on authentication failure: via impersonation (if _user_id is present) or via an OAuth token refresh (if an async OAuth token manager is configured)
if (
self._user_id
(self._user_id or self._async_oauth_token_manager)
and not self._401_has_retried.get()
and response.status_code
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
Expand Down Expand Up @@ -746,6 +778,21 @@ async def _handle_401_token_refresh(
Async version of token refresh and retry logic.
Handles token refresh and retries the API request upon a 401 Unauthorized response.
"""
if self._async_oauth_token_manager:
await self._async_oauth_token_manager.invalidate_token()
token = await self._async_oauth_token_manager.get_token()
params["headers"]["authorization"] = f"Bearer {token}"
self._401_has_retried.set(True)
LOGGER.debug("Successfully refreshed OAuth token after 401.")
return await self._call_api_internal(
api,
path,
params,
binary_data=binary_data,
download_file_path=download_file_path,
text_response=text_response,
)

try:
# Use sync impersonation call since it's a quick API call
new_token = await self.impersonate.user(user_id=self._user_id)
Expand Down
91 changes: 91 additions & 0 deletions pyatlan/client/aio/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 Atlan Pte. Ltd.
import asyncio
import time
from typing import Optional

import httpx
from authlib.oauth2.rfc6749 import OAuth2Token

from pyatlan.client.constants import GET_OAUTH_CLIENT
from pyatlan.utils import API


class AsyncOAuthTokenManager:
    """Fetch and cache an OAuth access token for the async client.

    The manager POSTs the client credentials to the token endpoint derived
    from ``GET_OAUTH_CLIENT`` and keeps the resulting token in memory until it
    is reported as expired or is explicitly invalidated.

    If no ``http_client`` is supplied, the manager creates its own
    ``httpx.AsyncClient`` and closes it in :meth:`aclose`; an injected client
    is never closed by the manager. The manager can also be used as an async
    context manager, which calls :meth:`aclose` on exit.
    """

    # Token lifetime (seconds) assumed when the response carries no usable expiry.
    _DEFAULT_EXPIRES_IN = 600

    def __init__(
        self,
        base_url: str,
        client_id: str,
        client_secret: str,
        http_client: Optional[httpx.AsyncClient] = None,
        connect_timeout: float = 30.0,
        read_timeout: float = 900.0,
    ):
        """Set up the manager.

        :param base_url: tenant base URL, or the literal ``"INTERNAL"`` to use
            the service address declared on the API endpoint
        :param client_id: OAuth client identifier sent as ``clientId``
        :param client_secret: OAuth client secret sent as ``clientSecret``
        :param http_client: optional pre-built client to reuse; when omitted a
            private client is created with the given timeouts
        :param connect_timeout: connect timeout (seconds) for the private client
        :param read_timeout: read timeout (seconds) for the private client
        """
        self.base_url = base_url
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = self._create_path(GET_OAUTH_CLIENT)
        self._lock = asyncio.Lock()
        self._http_client = http_client or httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=connect_timeout, read=read_timeout, write=30.0, pool=30.0
            )
        )
        self._token: Optional[OAuth2Token] = None
        # Only a client created here may be closed by aclose().
        self._owns_client = http_client is None

    async def __aenter__(self) -> "AsyncOAuthTokenManager":
        """Return the manager itself for use in ``async with``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Release the privately owned HTTP client, if any."""
        await self.aclose()

    @staticmethod
    def _coerce_expires_in(raw: object, default: int) -> int:
        """Turn a raw expiry value into a positive number of seconds.

        Anything that cannot be converted to a positive integer (``None``,
        a non-numeric string, zero, a negative number) yields ``default``.
        """
        try:
            seconds = int(raw)  # type: ignore[call-overload]
        except (TypeError, ValueError):
            return default
        return seconds if seconds > 0 else default

    async def get_token(self) -> str:
        """Return a valid access token, requesting a new one when needed.

        The lock serialises callers so that concurrent requests share a single
        token fetch instead of each hitting the token endpoint.

        :returns: the access token string
        :raises httpx.HTTPStatusError: if the token endpoint answers with an
            error status
        :raises ValueError: if the response body is not a JSON object or does
            not contain an access token
        """
        async with self._lock:
            if self._token and not self._token.is_expired():
                return str(self._token["access_token"])

            response = await self._http_client.post(
                self.token_url,
                json={
                    "clientId": self.client_id,
                    "clientSecret": self.client_secret,
                },
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()

            data = response.json()
            if not isinstance(data, dict):
                # Fail with a clear message instead of an AttributeError on .get().
                raise ValueError(
                    f"OAuth token response is not a JSON object "
                    f"(got {type(data).__name__})"
                )

            # Accept both camelCase and snake_case field names.
            access_token = data.get("accessToken") or data.get("access_token")
            if not access_token:
                raise ValueError(
                    f"OAuth token response missing 'accessToken' field. "
                    f"Response keys: {list(data.keys())}"
                )

            raw_expiry = data.get("expiresIn")
            if raw_expiry is None:
                raw_expiry = data.get("expires_in")
            # The expiry may arrive as a string or be null; normalise it so the
            # arithmetic below cannot raise after a successful token fetch.
            expires_in = self._coerce_expires_in(
                raw_expiry, self._DEFAULT_EXPIRES_IN
            )

            self._token = OAuth2Token(
                {
                    "access_token": access_token,
                    "token_type": data.get("tokenType")
                    or data.get("token_type", "Bearer"),
                    "expires_in": expires_in,
                    "expires_at": int(time.time()) + expires_in,
                }
            )

            # Same return type as the cache-hit path above.
            return str(access_token)

    async def invalidate_token(self) -> None:
        """Drop the cached token so the next ``get_token`` call fetches a new one."""
        async with self._lock:
            self._token = None

    def _create_path(self, api: API) -> str:
        """Build the absolute URL for ``api``.

        Uses the endpoint's own service address when ``base_url`` is
        ``"INTERNAL"``, otherwise the configured ``base_url``; the endpoint
        prefix and the API path are then joined onto that root.
        """
        from urllib.parse import urljoin

        root = api.endpoint.service if self.base_url == "INTERNAL" else self.base_url
        base_with_prefix = urljoin(root, api.endpoint.prefix)
        return urljoin(base_with_prefix, api.path)

    async def aclose(self) -> None:
        """Close the HTTP client if it was created by this manager."""
        if self._owns_client:
            await self._http_client.aclose()
58 changes: 51 additions & 7 deletions pyatlan/client/atlan.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pyatlan.client.file import FileClient
from pyatlan.client.group import GroupClient
from pyatlan.client.impersonate import ImpersonationClient
from pyatlan.client.oauth import OAuthTokenManager
from pyatlan.client.open_lineage import OpenLineageClient
from pyatlan.client.query import QueryClient
from pyatlan.client.role import RoleClient
Expand Down Expand Up @@ -127,7 +128,9 @@ def log_response(response, *args, **kwargs):

class AtlanClient(BaseSettings):
base_url: Union[Literal["INTERNAL"], HttpUrl]
api_key: str
api_key: Optional[str] = None
oauth_client_id: Optional[str] = None
oauth_client_secret: Optional[str] = None
connect_timeout: float = 30.0 # 30 secs
read_timeout: float = 900.0 # 15 mins
retry: Retry = DEFAULT_RETRY
Expand All @@ -137,6 +140,7 @@ class AtlanClient(BaseSettings):
_session: httpx.Client = PrivateAttr()
_request_params: dict = PrivateAttr()
_user_id: Optional[str] = PrivateAttr(default=None)
_oauth_token_manager: Optional[Any] = PrivateAttr(default=None)
_workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None)
_credential_client: Optional[CredentialClient] = PrivateAttr(default=None)
_admin_client: Optional[AdminClient] = PrivateAttr(default=None)
Expand Down Expand Up @@ -172,11 +176,33 @@ class Config:

def __init__(self, **data):
super().__init__(**data)
self._request_params = (
{"headers": {"authorization": f"Bearer {self.api_key}"}}
if self.api_key and self.api_key.strip()
else {"headers": {}}
)

if self.oauth_client_id and self.oauth_client_secret and self.api_key is None:
LOGGER.debug("API KEY not provided. Using OAuth flow for authentication")

final_base_url = self.base_url or os.environ.get(
"ATLAN_BASE_URL", "INTERNAL"
)
final_oauth_client_id = self.oauth_client_id or os.environ.get(
"ATLAN_OAUTH_CLIENT_ID"
)
final_oauth_client_secret = self.oauth_client_secret or os.environ.get(
"ATLAN_OAUTH_CLIENT_SECRET"
)
self._oauth_token_manager = OAuthTokenManager(
base_url=final_base_url,
client_id=final_oauth_client_id,
client_secret=final_oauth_client_secret,
connect_timeout=self.connect_timeout,
read_timeout=self.read_timeout,
)
self._request_params = {"headers": {}}
else:
self._request_params = (
{"headers": {"authorization": f"Bearer {self.api_key}"}}
if self.api_key and self.api_key.strip()
else {"headers": {}}
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Sync Client OAuth Leaks HTTP Resources

The sync AtlanClient creates _oauth_token_manager when OAuth credentials are provided but has no cleanup mechanism (no close() method or destructor). The token manager owns an httpx.Client that needs to be closed via self._oauth_token_manager.close(), but this never happens, causing a resource leak with unclosed HTTP connections whenever OAuth authentication is used with the sync client.

Fix in Cursor Fix in Web


# Build proxy/SSL configuration with environment variable fallback
transport_kwargs = self._build_transport_proxy_config(data)
Expand Down Expand Up @@ -691,7 +717,7 @@ def _call_api_internal(
# Retry with impersonation (if _user_id is present) or an OAuth token refresh
# (if an OAuth token manager is configured) on authentication failure (token may have expired)
if (
self._user_id
(self._user_id or self._oauth_token_manager)
and not self._401_has_retried.get()
and response.status_code
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
Expand Down Expand Up @@ -813,6 +839,9 @@ def _create_params(
self, api: API, query_params, request_obj, exclude_unset: bool = True
):
params = copy.deepcopy(self._request_params)
if self._oauth_token_manager:
token = self._oauth_token_manager.get_token()
params["headers"]["authorization"] = f"Bearer {token}"
params["headers"]["Accept"] = api.consumes
params["headers"]["content-type"] = api.produces
if query_params is not None:
Expand Down Expand Up @@ -846,6 +875,21 @@ def _handle_401_token_refresh(

returns: HTTP response received after retrying the request with the refreshed token
"""
if self._oauth_token_manager:
self._oauth_token_manager.invalidate_token()
token = self._oauth_token_manager.get_token()
params["headers"]["authorization"] = f"Bearer {token}"
self._401_has_retried.set(True)
LOGGER.debug("Successfully refreshed OAuth token after 401.")
return self._call_api_internal(
api,
path,
params,
binary_data=binary_data,
download_file_path=download_file_path,
text_response=text_response,
)

try:
new_token = self.impersonate.user(user_id=self._user_id)
except Exception as e:
Expand Down
8 changes: 8 additions & 0 deletions pyatlan/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@
GET_WHOAMI_USER = API(
WHOAMI_API, HTTPMethod.GET, HTTPStatus.OK, endpoint=EndPoint.HERACLES
)

# OAuth client authentication
GET_OAUTH_CLIENT = API(
"oauth-clients/token",
HTTPMethod.POST,
HTTPStatus.OK,
endpoint=EndPoint.HERACLES,
)
# SQL parsing APIs
PARSE_QUERY = API(
f"{QUERY_API}/parse", HTTPMethod.POST, HTTPStatus.OK, endpoint=EndPoint.HEKA
Expand Down
Loading
Loading