feat: Add support for both basic and Cognito bearer auth for Airflow API #9

Open
wants to merge 4 commits into
base: main
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
@@ -1,22 +1,22 @@
fail_fast: true
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
# Git style
- id: check-merge-conflict
- id: check-symlinks
- id: trailing-whitespace

- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 6.0.1
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
rev: 25.1.0
hooks:
- id: black
# It is recommended to specify the latest version of Python
@@ -27,24 +27,24 @@ repos:

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.5
rev: v0.12.3
hooks:
- id: ruff
args: ["--ignore", "E501,E402"]

- repo: https://github.com/PyCQA/bandit
rev: "1.7.8" # you must change this to newest version
rev: "1.8.6" # you must change this to newest version
hooks:
- id: bandit
args: ["--severity-level=high", "--confidence-level=high"]

- repo: https://github.com/PyCQA/prospector
rev: v1.10.3
rev: v1.17.2
hooks:
- id: prospector

- repo: https://github.com/antonbabenko/pre-commit-terraform
rev: v1.90.0 # Get the latest from: https://github.com/antonbabenko/pre-commit-terraform/releases
rev: v1.99.5 # Get the latest from: https://github.com/antonbabenko/pre-commit-terraform/releases
hooks:
# Terraform Tests
- id: terraform_fmt
80 changes: 80 additions & 0 deletions README.md
@@ -83,6 +83,86 @@ The Unity initiator is the set of compute resources that enable the routing of t

The initiator topic, an SNS topic, is the common interface that all triggers will submit events to. The initiator topic is subscribed to by the initiator SQS queue (complete with dead-letter queue for resiliency) which in turn is subscribed to by the router Lambda function. How the router Lambda routes payloads of the trigger events is defined by the router configuration YAML. The full YAML schema for the router configuration is located [here](src/unity_initiator/resources/routers_schema.yaml).

## Authentication in Router Configs

The Unity Initiator supports multiple authentication methods for submitting DAG runs to Airflow. You can use legacy Basic Auth, or Bearer token authentication using AWS Cognito (either OAuth2 or InitiateAuth flows). Choose the method that matches your Airflow API deployment and security requirements.

Below are full example router configs for each authentication method. See the [router schema](src/unity_initiator/resources/routers_schema.yaml) for all available fields.

### 1. Basic Auth (Legacy)
```yaml
initiator_config:
  name: basic-auth example
  payload_type:
    url:
      - regexes:
          - '.*\\.dat$'
        evaluators:
          - name: eval_basic
            actions:
              - name: submit_dag_by_id
                params:
                  dag_id: my_airflow_dag
                  airflow_base_api_endpoint: https://airflow.example.com/api/v1
                  auth_method: basic
                  airflow_username: my-airflow-username
                  airflow_password: my-airflow-password
```

### 2. Bearer Token (Cognito OAuth2)
```yaml
initiator_config:
  name: bearer-oauth2 example
  payload_type:
    url:
      - regexes:
          - '.*\\.dat$'
        evaluators:
          - name: eval_oauth2
            actions:
              - name: submit_dag_by_id
                params:
                  dag_id: my_airflow_dag
                  airflow_base_api_endpoint: https://airflow.example.com/api/v1
                  auth_method: bearer
                  cognito_token_method: oauth2
                  cognito_token_url: https://your-cognito-domain.auth.us-west-2.amazoncognito.com/oauth2/token
                  cognito_client_id: your-client-id
                  cognito_client_secret: your-client-secret
                  cognito_username: your-username
                  cognito_password: your-password
```

### 3. Bearer Token (Cognito InitiateAuth)
```yaml
initiator_config:
  name: bearer-initiate-auth example
  payload_type:
    url:
      - regexes:
          - '.*\\.dat$'
        evaluators:
          - name: eval_initauth
            actions:
              - name: submit_dag_by_id
                params:
                  dag_id: my_airflow_dag
                  airflow_base_api_endpoint: https://airflow.example.com/api/v1
                  auth_method: bearer
                  cognito_token_method: initiate_auth
                  cognito_region: us-west-2
                  cognito_client_id: your-client-id
                  cognito_username: your-username
                  cognito_password: your-password
```
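
Both bearer configurations rely on the action fetching a Cognito access token at run time. The `submit_dag_by_id.py` change in this PR reuses a cached token and fetches a new one once the cached token is within 60 seconds of expiry. A minimal sketch of that check follows; `needs_refresh` and `REFRESH_MARGIN_SECONDS` are hypothetical names used for illustration, not identifiers from the PR.

```python
import time

# Hypothetical constant: the PR hard-codes a 60-second margin inline.
REFRESH_MARGIN_SECONDS = 60


def needs_refresh(token, expiry, now=None, margin=REFRESH_MARGIN_SECONDS):
    """Return True when no token is cached or it expires within `margin` seconds."""
    if now is None:
        now = time.time()
    return not token or now > expiry - margin


# No cached token yet, so a fetch is required.
print(needs_refresh(None, 0, now=1000.0))        # True
# Token is valid for another hour, so it can be reused.
print(needs_refresh("abc", 4600.0, now=1000.0))  # False
# Token expires in 30 seconds, so it is refreshed early.
print(needs_refresh("abc", 1030.0, now=1000.0))  # True
```

Passing `now` explicitly keeps the check deterministic in tests; in normal use it falls back to `time.time()`.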

**When to use each method:**
- Use `basic` for legacy Airflow deployments with HTTP Basic Auth.
- Use `bearer` with `oauth2` for OIDC/JWT-based Airflow APIs (API Gateway/ALB with Cognito OIDC).
- Use `bearer` with `initiate_auth` for AWS-native Cognito integrations (if your API expects tokens from the InitiateAuth flow).
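
Because the router schema in this PR marks every auth-related parameter as optional, a missing field would surface only when the action runs. The sketch below shows one way to check a `params` mapping up front, assuming the field names from the three examples; `REQUIRED_PARAMS` and `missing_auth_params` are hypothetical helpers, not part of the Unity Initiator code, and the sketch ignores the optional pre-fetched `bearer_token`.

```python
# Hypothetical mapping from (auth_method, cognito_token_method) to the params
# each README example supplies.
REQUIRED_PARAMS = {
    ("basic", None): ["airflow_username", "airflow_password"],
    ("bearer", "oauth2"): [
        "cognito_token_url",
        "cognito_client_id",
        "cognito_client_secret",
        "cognito_username",
        "cognito_password",
    ],
    ("bearer", "initiate_auth"): [
        "cognito_region",
        "cognito_client_id",
        "cognito_username",
        "cognito_password",
    ],
}


def missing_auth_params(params):
    """Return the names of auth params that the selected method needs but lacks."""
    # Defaults mirror the PR: "basic" auth, and "oauth2" as the Cognito method.
    if params.get("auth_method", "basic") == "bearer":
        method = params.get("cognito_token_method", "oauth2")
        key = ("bearer", "initiate_auth" if method == "initiate_auth" else "oauth2")
    else:
        key = ("basic", None)
    return [name for name in REQUIRED_PARAMS[key] if name not in params]


example = {
    "auth_method": "bearer",
    "cognito_token_method": "initiate_auth",
    "cognito_client_id": "your-client-id",
}
print(missing_auth_params(example))
# ['cognito_region', 'cognito_username', 'cognito_password']
```

An empty list means the mapping carries every field the chosen method reads.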

For more advanced usage (e.g., on_success actions, multiple evaluators, or other action types), see the schema and additional documentation below.

#### How the router works

In the context of trigger events where a new file is detected (payload_type=`url`), the router Lambda extracts the URL of the new file, instantiates a router object, and attempts to match the URL against a set of regular expressions defined in the router configuration file. Let's consider this minimal router configuration YAML file example:
2 changes: 1 addition & 1 deletion src/unity_initiator/__about__.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2024-present Gerald Manipon <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
__version__ = "0.0.1"
__version__ = "0.0.2"
76 changes: 73 additions & 3 deletions src/unity_initiator/actions/submit_dag_by_id.py
@@ -1,3 +1,4 @@
import time
import uuid
from datetime import datetime

@@ -9,21 +10,90 @@
__all__ = ["SubmitDagByID"]


def fetch_cognito_token_oauth2(token_url, client_id, client_secret, username, password):
    data = {
        "grant_type": "password",
        "client_id": client_id,
        "client_secret": client_secret,
        "username": username,
        "password": password,
        "scope": "openid",
    }
    response = httpx.post(token_url, data=data)
    response.raise_for_status()
    token_data = response.json()
    return token_data["access_token"], time.time() + token_data.get("expires_in", 3600)


def fetch_cognito_token_initiate_auth(region, client_id, username, password):
    url = f"https://cognito-idp.{region}.amazonaws.com"
    payload = {
        "AuthParameters": {"USERNAME": f"{username}", "PASSWORD": f"{password}"},
        "AuthFlow": "USER_PASSWORD_AUTH",
        "ClientId": f"{client_id}",
    }
    headers = {
        "X-Amz-Target": "AWSCognitoIdentityProviderService.InitiateAuth",
        "Content-Type": "application/x-amz-json-1.1",
    }
    res = httpx.post(url, json=payload, headers=headers).json()
    if "AuthenticationResult" in res:
        access_token = res["AuthenticationResult"]["AccessToken"]
        # Cognito AccessToken is valid for 1 hour by default
        return access_token, time.time() + 3600
    raise RuntimeError(f"Failed to fetch Cognito token: {res}")


class SubmitDagByID(Action):
    def __init__(self, payload, payload_info, params):
        super().__init__(payload, payload_info, params)
        logger.info("instantiated %s", __class__.__name__)

    def execute(self):
        # TODO: flesh this method out completely in accordance with:
        # https://airflow.apache.org/docs/apache-airflow/stable/stable-rest-api-ref.html#operation/post_dag_run
        logger.debug("executing execute in %s", __class__.__name__)
        url = f"{self._params['airflow_base_api_endpoint']}/dags/{self._params['dag_id']}/dagRuns"
        logger.info("url: %s", url)
        dag_run_id = str(uuid.uuid4())
        logical_date = datetime.now().strftime("%Y-%m-%dT%H:%M:%S.%fZ")
        headers = {"Content-Type": "application/json", "Accept": "application/json"}
        auth = (self._params["airflow_username"], self._params["airflow_password"])
        auth = None

        # Determine authentication method
        auth_method = self._params.get("auth_method", "basic")
        if auth_method == "bearer":
            # Support both Cognito token fetch methods
            token = self._params.get("bearer_token")
            expiry = self._params.get("bearer_token_expiry", 0)
            now = time.time()
            if not token or now > expiry - 60:  # refresh 1 min before expiry
                token_method = self._params.get("cognito_token_method", "oauth2")
                if token_method == "initiate_auth":
                    logger.info("Fetching Cognito bearer token using InitiateAuth...")
                    token, expiry = fetch_cognito_token_initiate_auth(
                        self._params["cognito_region"],
                        self._params["cognito_client_id"],
                        self._params["cognito_username"],
                        self._params["cognito_password"],
                    )
                else:
                    logger.info(
                        "Fetching Cognito bearer token using OAuth2 password grant..."
                    )
                    token, expiry = fetch_cognito_token_oauth2(
                        self._params["cognito_token_url"],
                        self._params["cognito_client_id"],
                        self._params["cognito_client_secret"],
                        self._params["cognito_username"],
                        self._params["cognito_password"],
                    )
                self._params["bearer_token"] = token
                self._params["bearer_token_expiry"] = expiry
            headers["Authorization"] = f"Bearer {token}"
            auth = None
        else:
            # Default to basic auth
            auth = (self._params["airflow_username"], self._params["airflow_password"])

body = {
"dag_run_id": dag_run_id,
"logical_date": logical_date,
16 changes: 16 additions & 0 deletions src/unity_initiator/resources/routers_schema.yaml
@@ -44,8 +44,24 @@ submit_dag_by_id_action:
params:
dag_id: str()
airflow_base_api_endpoint: str(required=False)
# Auth method: 'basic' (default) or 'bearer'
auth_method: enum("basic", "bearer", required=False)
# For basic auth (legacy)
airflow_username: str(required=False)
airflow_password: str(required=False)
# For bearer token auth (Cognito)
cognito_client_id: str(required=False)
cognito_client_secret: str(required=False)
cognito_token_url: str(required=False)
cognito_username: str(required=False)
cognito_password: str(required=False)
# Optionally, allow passing a pre-fetched bearer token and its expiry
bearer_token: str(required=False)
bearer_token_expiry: int(required=False)
# Cognito token method: 'oauth2' (default) or 'initiate_auth'
cognito_token_method: enum("oauth2", "initiate_auth", required=False)
# Cognito region (required for initiate_auth)
cognito_region: str(required=False)
on_success: include("on_success_actions", required=False)

# Configuration for submitting a payload to an SNS topic.
4 changes: 2 additions & 2 deletions terraform-unity/centralized_log_group/README.md
@@ -1,6 +1,6 @@
# terraform-unity

<!-- BEGINNING OF PRE-COMMIT-TERRAFORM DOCS HOOK -->
<!-- BEGIN_TF_DOCS -->
## Requirements

| Name | Version |
@@ -37,4 +37,4 @@ No modules.
| Name | Description |
|------|-------------|
| <a name="output_centralized_log_group_name"></a> [centralized\_log\_group\_name](#output\_centralized\_log\_group\_name) | The name of the centralized log group |
<!-- END OF PRE-COMMIT-TERRAFORM DOCS HOOK -->
<!-- END_TF_DOCS -->
4 changes: 2 additions & 2 deletions terraform-unity/evaluators/sns-sqs-lambda/README.md
@@ -1,6 +1,6 @@
# sns_sqs_lambda

<!-- BEGINNING OF PRE-COMMIT-TERRAFORM DOCS HOOK -->
<!-- BEGIN_TF_DOCS -->
## Requirements

| Name | Version |
@@ -62,4 +62,4 @@ No modules.
| Name | Description |
|------|-------------|
| <a name="output_evaluator_topic_arn"></a> [evaluator\_topic\_arn](#output\_evaluator\_topic\_arn) | The ARN of the evaluator SNS topic |
<!-- END OF PRE-COMMIT-TERRAFORM DOCS HOOK -->
<!-- END_TF_DOCS -->
4 changes: 2 additions & 2 deletions terraform-unity/initiator/README.md
@@ -1,6 +1,6 @@
# terraform-unity

<!-- BEGINNING OF PRE-COMMIT-TERRAFORM DOCS HOOK -->
<!-- BEGIN_TF_DOCS -->
## Requirements

| Name | Version |
@@ -60,4 +60,4 @@ No modules.
| Name | Description |
|------|-------------|
| <a name="output_initiator_topic_arn"></a> [initiator\_topic\_arn](#output\_initiator\_topic\_arn) | The ARN of the initiator SNS topic |
<!-- END OF PRE-COMMIT-TERRAFORM DOCS HOOK -->
<!-- END_TF_DOCS -->