diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 986fcc3..7a3c2c6 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,7 @@
fail_fast: true
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.6.0
+ rev: v5.0.0
hooks:
# Git style
- id: check-merge-conflict
@@ -9,14 +9,14 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/pycqa/isort
- rev: 5.13.2
+ rev: 6.0.1
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]
# Using this mirror lets us use mypyc-compiled black, which is about 2x faster
- repo: https://github.com/psf/black-pre-commit-mirror
- rev: 24.4.2
+ rev: 25.1.0
hooks:
- id: black
# It is recommended to specify the latest version of Python
@@ -27,24 +27,24 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
- rev: v0.4.5
+ rev: v0.12.3
hooks:
- id: ruff
args: ["--ignore", "E501,E402"]
- repo: https://github.com/PyCQA/bandit
- rev: "1.7.8" # you must change this to newest version
+ rev: "1.8.6" # you must change this to newest version
hooks:
- id: bandit
args: ["--severity-level=high", "--confidence-level=high"]
- repo: https://github.com/PyCQA/prospector
- rev: v1.10.3
+ rev: v1.17.2
hooks:
- id: prospector
- repo: https://github.com/antonbabenko/pre-commit-terraform
- rev: v1.90.0 # Get the latest from: https://github.com/antonbabenko/pre-commit-terraform/releases
+ rev: v1.99.5 # Get the latest from: https://github.com/antonbabenko/pre-commit-terraform/releases
hooks:
# Terraform Tests
- id: terraform_fmt
diff --git a/.prospector.yaml b/.prospector.yaml
index 8995ea9..f76baae 100644
--- a/.prospector.yaml
+++ b/.prospector.yaml
@@ -10,6 +10,11 @@ pylint:
disable:
- import-error
- django-not-available
+ - import-outside-toplevel
+ - no-else-return
+ - consider-using-sys-exit
+ - too-many-arguments
+ - too-many-positional-arguments
options:
max-line-length: 159
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7151834..7747ec5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,7 +5,32 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-## [X.Y.Z] - 2022-MM-DD
+## [0.0.2] - 2024-12-19
+
+### Added
+
+- Cognito token authentication support with automatic refresh
+- Bearer token authentication for enhanced security
+- `TokenManager` class for token lifecycle management
+- `fetch_cognito_token()` function for Cognito integration
+- `get_auth_headers()` utility for authentication headers
+- Comprehensive test coverage for authentication features
+- Example scripts demonstrating token usage
+- Updated documentation for authentication configuration
+
+### Changed
+
+- Enhanced `SubmitDagByID` action to support multiple authentication methods
+- Added `httpx` dependency for modern HTTP client functionality
+- Maintained backward compatibility with existing basic auth
+
+### Security
+
+- Replaced basic authentication with more secure Bearer token authentication
+- Added automatic token refresh to prevent authentication failures
+- Implemented token caching to reduce API calls to Cognito
+
+## [0.0.1] - 2022-MM-DD
### Added
diff --git a/README.md b/README.md
index fa41d3f..b685dc2 100644
--- a/README.md
+++ b/README.md
@@ -160,6 +160,86 @@ and a trigger event payload for a new file that was triggered:
In this case, the router sees that the action is `submit_dag_by_id` and thus makes a REST call to SPS to submit the URL payload, payload info, and `on_success` parameters as a DAG run. If the evaulator, running now as a DAG in SPS instead of an AWS Lambda function, successfully evaluates that everything is ready for this input file, it can proceed to submit a DAG run for the `submit_nisar_l0a_te_dag` DAG in the underlying SPS.
+### Authentication for Airflow DAG Submissions
+
+The `submit_dag_by_id` action supports multiple authentication methods for connecting to Airflow REST APIs. The authentication method is determined by the parameters provided in the router configuration:
+
+#### 1. Bearer Token Authentication (Recommended)
+Use a direct bearer token for authentication. This is the most secure method:
+
+```yaml
+actions:
+ - name: submit_dag_by_id
+ params:
+ dag_id: example_dag
+ airflow_base_api_endpoint: https://airflow.example.com/api/v1
+ airflow_token: ${AIRFLOW_BEARER_TOKEN} # Bearer token
+```
+
+#### 2. OAuth2 Authentication (For Proxy Servers)
+Use OAuth2 authorization code flow for proxy authentication:
+
+```yaml
+actions:
+ - name: submit_dag_by_id
+ params:
+ dag_id: example_dag
+ airflow_base_api_endpoint: https://proxy.example.com/api/v1
+ oauth2_cognito_domain: your-domain.auth.us-west-2.amazoncognito.com
+ oauth2_client_id: your-oauth2-client-id
+ oauth2_redirect_uri: https://your-app.com/callback
+ oauth2_scope: openid email profile # Optional, defaults to "openid email profile"
+ oauth2_region: us-west-2 # Optional, defaults to us-west-2
+ oauth2_verify_ssl: true # Optional, defaults to true for security
+```
+
+**OAuth2 Flow Setup**:
+1. Use the provided `oauth2_token_init.py` script to initialize tokens
+2. The script will guide you through the authorization flow
+3. Tokens are automatically refreshed when needed
+
+#### 3. Cognito Token Authentication
+Use Unity Cognito credentials to automatically fetch and refresh tokens:
+
+```yaml
+actions:
+ - name: submit_dag_by_id
+ params:
+ dag_id: example_dag
+ airflow_base_api_endpoint: https://airflow.example.com/api/v1
+ unity_username: ${UNITY_USERNAME}
+ unity_password: ${UNITY_PASSWORD}
+ unity_client_id: ${UNITY_CLIENT_ID}
+ unity_region: us-west-2 # Optional, defaults to us-west-2
+```
+
+#### 4. Basic Authentication (Legacy)
+Use username/password for basic authentication (less secure):
+
+```yaml
+actions:
+ - name: submit_dag_by_id
+ params:
+ dag_id: example_dag
+ airflow_base_api_endpoint: https://airflow.example.com/api/v1
+ airflow_username: ${AIRFLOW_USERNAME}
+ airflow_password: ${AIRFLOW_PASSWORD}
+```
+
+#### Authentication Priority
+The system will use authentication in this order:
+1. **Bearer token** (if `airflow_token` is provided)
+2. **OAuth2 token** (if OAuth2 credentials are provided)
+3. **Cognito token** (if Unity credentials are provided)
+4. **Basic auth** (if username/password are provided)
+5. **No authentication** (if no credentials are provided)
+
+#### Token Management
+When using Cognito authentication:
+- Tokens are automatically cached and refreshed 5 minutes before expiration
+- Failed token refresh attempts fall back to credential-based fetching
+- No manual token management required
+
+
## Requirements
| Name | Version |
@@ -37,4 +37,4 @@ No modules.
| Name | Description |
|------|-------------|
| [centralized\_log\_group\_name](#output\_centralized\_log\_group\_name) | The name of the centralized log group |
-
+
diff --git a/terraform-unity/evaluators/sns-sqs-lambda/README.md b/terraform-unity/evaluators/sns-sqs-lambda/README.md
index 9d1aad5..8b8f2f8 100644
--- a/terraform-unity/evaluators/sns-sqs-lambda/README.md
+++ b/terraform-unity/evaluators/sns-sqs-lambda/README.md
@@ -1,6 +1,6 @@
# sns_sqs_lambda
-
+
## Requirements
| Name | Version |
@@ -62,4 +62,4 @@ No modules.
| Name | Description |
|------|-------------|
| [evaluator\_topic\_arn](#output\_evaluator\_topic\_arn) | The ARN of the evaluator SNS topic |
-
+
diff --git a/terraform-unity/initiator/README.md b/terraform-unity/initiator/README.md
index 64d43b7..4d79947 100644
--- a/terraform-unity/initiator/README.md
+++ b/terraform-unity/initiator/README.md
@@ -1,6 +1,6 @@
# terraform-unity
-
+
## Requirements
| Name | Version |
@@ -60,4 +60,4 @@ No modules.
| Name | Description |
|------|-------------|
| [initiator\_topic\_arn](#output\_initiator\_topic\_arn) | The ARN of the initiator SNS topic |
-
+
diff --git a/terraform-unity/triggers/cmr-query/README.md b/terraform-unity/triggers/cmr-query/README.md
index 64511ff..2fd97bf 100644
--- a/terraform-unity/triggers/cmr-query/README.md
+++ b/terraform-unity/triggers/cmr-query/README.md
@@ -1,6 +1,6 @@
# scheduled_task
-
+
## Requirements
| Name | Version |
@@ -63,4 +63,4 @@ No modules.
## Outputs
No outputs.
-
+
diff --git a/terraform-unity/triggers/s3-bucket-notification/README.md b/terraform-unity/triggers/s3-bucket-notification/README.md
index 14c7a4e..49f3961 100644
--- a/terraform-unity/triggers/s3-bucket-notification/README.md
+++ b/terraform-unity/triggers/s3-bucket-notification/README.md
@@ -1,6 +1,6 @@
# s3_bucket_notification
-
+
## Requirements
| Name | Version |
@@ -38,4 +38,4 @@ No modules.
## Outputs
No outputs.
-
+
diff --git a/terraform-unity/triggers/scheduled-task-instrumented/README.md b/terraform-unity/triggers/scheduled-task-instrumented/README.md
index ecdd896..52f16e3 100644
--- a/terraform-unity/triggers/scheduled-task-instrumented/README.md
+++ b/terraform-unity/triggers/scheduled-task-instrumented/README.md
@@ -1,6 +1,6 @@
# scheduled_task
-
+
## Requirements
| Name | Version |
@@ -54,4 +54,4 @@ No modules.
## Outputs
No outputs.
-
+
diff --git a/terraform-unity/triggers/scheduled-task/README.md b/terraform-unity/triggers/scheduled-task/README.md
index 163c4bd..31c9871 100644
--- a/terraform-unity/triggers/scheduled-task/README.md
+++ b/terraform-unity/triggers/scheduled-task/README.md
@@ -1,6 +1,6 @@
# scheduled_task
-
+
## Requirements
| Name | Version |
@@ -49,4 +49,4 @@ No modules.
## Outputs
No outputs.
-
+
diff --git a/tests/resources/test_router_with_auth.yaml b/tests/resources/test_router_with_auth.yaml
new file mode 100644
index 0000000..862506d
--- /dev/null
+++ b/tests/resources/test_router_with_auth.yaml
@@ -0,0 +1,55 @@
+initiator_config:
+ name: test config with authentication examples
+
+ payload_type:
+ url:
+ # Test with Bearer token authentication
+ - regexes:
+ - '/(?Pbearer_test_(?P\d{10})\.json)$'
+ evaluators:
+ - name: eval_bearer_auth
+ actions:
+ - name: submit_dag_by_id
+ params:
+ dag_id: test_bearer_auth_dag
+ airflow_base_api_endpoint: https://airflow.example.com/api/v1
+ airflow_token: test-bearer-token-123
+
+ # Test with Cognito authentication
+ - regexes:
+ - '/(?Pcognito_test_(?P\d{10})\.json)$'
+ evaluators:
+ - name: eval_cognito_auth
+ actions:
+ - name: submit_dag_by_id
+ params:
+ dag_id: test_cognito_auth_dag
+ airflow_base_api_endpoint: https://airflow.example.com/api/v1
+ unity_username: testuser
+ unity_password: testpass
+ unity_client_id: test-client-id
+ unity_region: us-west-2
+
+ # Test with Basic authentication
+ - regexes:
+ - '/(?Pbasic_test_(?P\d{10})\.json)$'
+ evaluators:
+ - name: eval_basic_auth
+ actions:
+ - name: submit_dag_by_id
+ params:
+ dag_id: test_basic_auth_dag
+ airflow_base_api_endpoint: https://airflow.example.com/api/v1
+ airflow_username: testuser
+ airflow_password: testpass
+
+ # Test with no authentication
+ - regexes:
+ - '/(?Pno_auth_test_(?P\d{10})\.json)$'
+ evaluators:
+ - name: eval_no_auth
+ actions:
+ - name: submit_dag_by_id
+ params:
+ dag_id: test_no_auth_dag
+ airflow_base_api_endpoint: https://airflow.example.com/api/v1
\ No newline at end of file
diff --git a/tests/test_auth_utils.py b/tests/test_auth_utils.py
new file mode 100644
index 0000000..04e6e78
--- /dev/null
+++ b/tests/test_auth_utils.py
@@ -0,0 +1,272 @@
+import time
+from unittest.mock import Mock, patch
+
+from unity_initiator.utils.auth_utils import (
+ TokenInfo,
+ TokenManager,
+ fetch_cognito_token,
+ get_auth_headers,
+)
+
+
+class TestAuthUtils:
+ """Test authentication utilities."""
+
+ def test_get_auth_headers_basic(self):
+ """Test basic authentication headers."""
+ headers = get_auth_headers(
+ auth_type="basic", username="testuser", password="testpass"
+ )
+
+ assert "Authorization" in headers
+ assert headers["Authorization"].startswith("Basic ")
+ assert headers["Content-Type"] == "application/json"
+ assert headers["Accept"] == "application/json"
+
+ def test_get_auth_headers_bearer(self):
+ """Test bearer token authentication headers."""
+ token = "test-token-123"
+ headers = get_auth_headers(auth_type="bearer", token=token)
+
+ assert "Authorization" in headers
+ assert headers["Authorization"] == f"Bearer {token}"
+ assert headers["Content-Type"] == "application/json"
+ assert headers["Accept"] == "application/json"
+
+ def test_get_auth_headers_no_auth(self):
+ """Test headers without authentication."""
+ headers = get_auth_headers()
+
+ assert "Authorization" not in headers
+ assert headers["Content-Type"] == "application/json"
+ assert headers["Accept"] == "application/json"
+
+ @patch("unity_initiator.utils.auth_utils.TokenManager")
+ def test_fetch_cognito_token_success(self, mock_token_manager_class):
+ """Test successful Cognito token fetching."""
+ # Mock TokenManager instance
+ mock_manager = Mock()
+ mock_manager.get_valid_token.return_value = "test-access-token-123"
+ mock_token_manager_class.return_value = mock_manager
+
+ token = fetch_cognito_token(
+ username="testuser", password="testpass", client_id="test-client-id"
+ )
+
+ assert token == "test-access-token-123"
+ mock_token_manager_class.assert_called_once_with(
+ "testuser",
+ "testpass",
+ "test-client-id",
+ "us-west-2",
+ )
+ mock_manager.get_valid_token.assert_called_once()
+
+ @patch("unity_initiator.utils.auth_utils.TokenManager")
+ def test_fetch_cognito_token_no_auth_result(self, mock_token_manager_class):
+ """Test Cognito token fetching with no authentication result."""
+ # Mock TokenManager instance that returns None
+ mock_manager = Mock()
+ mock_manager.get_valid_token.return_value = None
+ mock_token_manager_class.return_value = mock_manager
+
+ token = fetch_cognito_token(
+ username="testuser", password="testpass", client_id="test-client-id"
+ )
+
+ assert token is None
+
+ @patch("unity_initiator.utils.auth_utils.TokenManager")
+ def test_fetch_cognito_token_http_error(self, mock_token_manager_class):
+ """Test Cognito token fetching with HTTP error."""
+ # Mock TokenManager instance that returns None due to error
+ mock_manager = Mock()
+ mock_manager.get_valid_token.return_value = None
+ mock_token_manager_class.return_value = mock_manager
+
+ token = fetch_cognito_token(
+ username="testuser", password="testpass", client_id="test-client-id"
+ )
+
+ assert token is None
+
+ @patch("unity_initiator.utils.auth_utils.TokenManager")
+ def test_fetch_cognito_token_request_error(self, mock_token_manager_class):
+ """Test Cognito token fetching with request error."""
+ # Mock TokenManager instance that returns None due to error
+ mock_manager = Mock()
+ mock_manager.get_valid_token.return_value = None
+ mock_token_manager_class.return_value = mock_manager
+
+ token = fetch_cognito_token(
+ username="testuser", password="testpass", client_id="test-client-id"
+ )
+
+ assert token is None
+
+
+class TestTokenManager:
+ """Test TokenManager class."""
+
+ def test_token_manager_init(self):
+ """Test TokenManager initialization."""
+ manager = TokenManager("user", "pass", "client-id", "us-west-2")
+
+ assert manager.username == "user"
+ assert manager.password == "pass"
+ assert manager.client_id == "client-id"
+ assert manager.region == "us-west-2"
+ assert manager._token_cache is None
+ assert manager._refresh_buffer == 300
+
+ def test_token_info_dataclass(self):
+ """Test TokenInfo dataclass."""
+ token_info = TokenInfo(
+ access_token="test-token",
+ expires_at=time.time() + 3600,
+ refresh_token="refresh-token",
+ )
+
+ assert token_info.access_token == "test-token"
+ assert token_info.expires_at > time.time()
+ assert token_info.refresh_token == "refresh-token"
+
+ @patch("unity_initiator.utils.auth_utils.httpx.Client")
+ def test_get_valid_token_first_time(self, mock_client_class):
+ """Test getting token for the first time."""
+ # Mock successful response
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "AuthenticationResult": {
+ "AccessToken": "test-token-123",
+ "ExpiresIn": 3600,
+ "RefreshToken": "refresh-token-123",
+ }
+ }
+ mock_response.raise_for_status.return_value = None
+
+ mock_client = Mock()
+ mock_client.post.return_value = mock_response
+ mock_client_class.return_value.__enter__.return_value = mock_client
+ mock_client_class.return_value.__exit__.return_value = None
+
+ manager = TokenManager("user", "pass", "client-id")
+ token = manager.get_valid_token()
+
+ assert token == "test-token-123"
+ assert manager._token_cache is not None
+ assert manager._token_cache.access_token == "test-token-123"
+ assert manager._token_cache.refresh_token == "refresh-token-123"
+
+ def test_get_valid_token_cached_valid(self):
+ """Test getting token when cached token is still valid."""
+ # Create a token that expires in 1 hour
+ expires_at = time.time() + 3600
+ token_info = TokenInfo(
+ access_token="cached-token",
+ expires_at=expires_at,
+ refresh_token="refresh-token",
+ )
+
+ manager = TokenManager("user", "pass", "client-id")
+ manager._token_cache = token_info
+
+ token = manager.get_valid_token()
+
+ assert token == "cached-token"
+
+ def test_get_valid_token_cached_expired(self):
+ """Test getting token when cached token is expired."""
+ # Create a token that expired 1 hour ago
+ expires_at = time.time() - 3600
+ token_info = TokenInfo(
+ access_token="expired-token",
+ expires_at=expires_at,
+ refresh_token="refresh-token",
+ )
+
+ manager = TokenManager("user", "pass", "client-id")
+ manager._token_cache = token_info
+
+ # Mock the _fetch_new_token method
+ with patch.object(manager, "_fetch_new_token", return_value="new-token"):
+ token = manager.get_valid_token()
+
+ assert token == "new-token"
+
+ def test_get_valid_token_cached_expiring_soon(self):
+ """Test getting token when cached token is expiring soon."""
+ # Create a token that expires in 2 minutes (less than 5-minute buffer)
+ expires_at = time.time() + 120
+ token_info = TokenInfo(
+ access_token="expiring-token",
+ expires_at=expires_at,
+ refresh_token="refresh-token",
+ )
+
+ manager = TokenManager("user", "pass", "client-id")
+ manager._token_cache = token_info
+
+ # Mock the _fetch_new_token method
+ with patch.object(manager, "_fetch_new_token", return_value="new-token"):
+ token = manager.get_valid_token()
+
+ assert token == "new-token"
+
+ @patch("unity_initiator.utils.auth_utils.httpx.Client")
+ def test_refresh_token_success(self, mock_client_class):
+ """Test successful token refresh."""
+ # Mock successful refresh response
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "AuthenticationResult": {
+ "AccessToken": "refreshed-token",
+ "ExpiresIn": 3600,
+ }
+ }
+ mock_response.raise_for_status.return_value = None
+
+ mock_client = Mock()
+ mock_client.post.return_value = mock_response
+ mock_client_class.return_value.__enter__.return_value = mock_client
+ mock_client_class.return_value.__exit__.return_value = None
+
+ manager = TokenManager("user", "pass", "client-id")
+ manager._token_cache = TokenInfo(
+ access_token="old-token",
+ expires_at=time.time() - 3600, # Expired
+ refresh_token="refresh-token",
+ )
+
+ token = manager._refresh_token()
+
+ assert token == "refreshed-token"
+ assert manager._token_cache.access_token == "refreshed-token"
+
+ def test_refresh_token_no_refresh_token(self):
+ """Test refresh when no refresh token is available."""
+ manager = TokenManager("user", "pass", "client-id")
+ manager._token_cache = TokenInfo(
+ access_token="old-token",
+ expires_at=time.time() - 3600, # Expired
+ refresh_token=None, # No refresh token
+ )
+
+ # Mock the _fetch_new_token method
+ with patch.object(manager, "_fetch_new_token", return_value="new-token"):
+ token = manager._refresh_token()
+
+ assert token == "new-token"
+
+ def test_clear_cache(self):
+ """Test clearing the token cache."""
+ manager = TokenManager("user", "pass", "client-id")
+ manager._token_cache = TokenInfo(
+ access_token="test-token",
+ expires_at=time.time() + 3600,
+ refresh_token="refresh-token",
+ )
+
+ manager.clear_cache()
+
+ assert manager._token_cache is None
diff --git a/tests/test_oauth2_utils.py b/tests/test_oauth2_utils.py
new file mode 100644
index 0000000..fd8db39
--- /dev/null
+++ b/tests/test_oauth2_utils.py
@@ -0,0 +1,353 @@
+"""
+Tests for OAuth2 authentication utilities.
+"""
+
+import time
+from unittest.mock import Mock, patch
+from urllib.parse import parse_qs
+
+import httpx
+import pytest
+
+from unity_initiator.utils.oauth2_utils import (
+ OAuth2Manager,
+ extract_authorization_code_from_url,
+ get_oauth2_headers,
+)
+
+
+class TestOAuth2Manager:
+ """Test OAuth2Manager class."""
+
+ def test_init(self):
+ """Test OAuth2Manager initialization."""
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ assert manager.cognito_domain == "test.auth.us-west-2.amazoncognito.com"
+ assert manager.client_id == "test-client-id"
+ assert manager.redirect_uri == "https://example.com/callback"
+ assert manager.scope == "openid email profile"
+ assert manager.region == "us-west-2"
+ assert manager.verify_ssl is True
+ assert (
+ manager.auth_endpoint
+ == "https://test.auth.us-west-2.amazoncognito.com/oauth2/authorize"
+ )
+ assert (
+ manager.token_endpoint
+ == "https://test.auth.us-west-2.amazoncognito.com/oauth2/token"
+ )
+
+ def test_init_with_verify_ssl_false(self):
+ """Test initialization with verify_ssl=False."""
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ verify_ssl=False,
+ )
+
+ assert manager.verify_ssl is False
+
+ def test_get_authorization_url(self):
+ """Test authorization URL generation."""
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ auth_url = manager.get_authorization_url()
+
+ # Parse URL and check parameters
+ parsed = httpx.URL(auth_url)
+ params = parse_qs(
+ parsed.query.decode()
+ if hasattr(parsed.query, "decode")
+ else str(parsed.query)
+ )
+
+ assert parsed.scheme == "https"
+ assert parsed.host == "test.auth.us-west-2.amazoncognito.com"
+ assert parsed.path == "/oauth2/authorize"
+ assert params["response_type"] == ["code"]
+ assert params["client_id"] == ["test-client-id"]
+ assert params["redirect_uri"] == ["https://example.com/callback"]
+ assert params["scope"] == ["openid email profile"]
+ assert "state" in params
+ assert "nonce" in params
+
+ def test_get_authorization_url_with_custom_state(self):
+ """Test authorization URL generation with custom state."""
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ auth_url = manager.get_authorization_url(state="custom-state")
+
+ parsed = httpx.URL(auth_url)
+ params = parse_qs(
+ parsed.query.decode()
+ if hasattr(parsed.query, "decode")
+ else str(parsed.query)
+ )
+
+ assert params["state"] == ["custom-state"]
+
+ @patch("unity_initiator.utils.oauth2_utils.httpx.Client")
+ def test_exchange_code_for_token_success(self, mock_client_class):
+ """Test successful token exchange."""
+ # Mock response
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "access_token": "test-access-token",
+ "refresh_token": "test-refresh-token",
+ "expires_in": 3600,
+ "token_type": "Bearer",
+ }
+ mock_response.raise_for_status.return_value = None
+
+ # Mock client
+ mock_client = Mock()
+ mock_client.post.return_value = mock_response
+ mock_client_class.return_value.__enter__.return_value = mock_client
+
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ result = manager.exchange_code_for_token("test-auth-code")
+
+ # Verify client was called correctly
+ mock_client.post.assert_called_once()
+ call_args = mock_client.post.call_args
+ assert (
+ call_args[0][0]
+ == "https://test.auth.us-west-2.amazoncognito.com/oauth2/token"
+ )
+
+ # Verify form data
+ form_data = call_args[1]["data"]
+ assert form_data["grant_type"] == "authorization_code"
+ assert form_data["client_id"] == "test-client-id"
+ assert form_data["code"] == "test-auth-code"
+ assert form_data["redirect_uri"] == "https://example.com/callback"
+
+ # Verify headers
+ headers = call_args[1]["headers"]
+ assert headers["Content-Type"] == "application/x-www-form-urlencoded"
+
+ # Verify result
+ assert result["access_token"] == "test-access-token"
+ assert result["refresh_token"] == "test-refresh-token"
+ assert result["expires_in"] == 3600
+
+ # Verify tokens were stored
+ assert manager._access_token == "test-access-token"
+ assert manager._refresh_token == "test-refresh-token"
+ assert manager._token_expires_at is not None
+
+ @patch("unity_initiator.utils.oauth2_utils.httpx.Client")
+ def test_exchange_code_for_token_http_error(self, mock_client_class):
+ """Test token exchange with HTTP error."""
+ # Mock HTTP error response
+ mock_response = Mock()
+ mock_response.text = "Invalid authorization code"
+ mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+ "400 Bad Request", request=Mock(), response=mock_response
+ )
+
+ # Mock client
+ mock_client = Mock()
+ mock_client.post.return_value = mock_response
+ mock_client_class.return_value.__enter__.return_value = mock_client
+
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ with pytest.raises(httpx.HTTPStatusError):
+ manager.exchange_code_for_token("invalid-code")
+
+ @patch("unity_initiator.utils.oauth2_utils.httpx.Client")
+ def test_refresh_access_token_success(self, mock_client_class):
+ """Test successful token refresh."""
+ # Mock response
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "access_token": "new-access-token",
+ "expires_in": 3600,
+ "token_type": "Bearer",
+ }
+ mock_response.raise_for_status.return_value = None
+
+ # Mock client
+ mock_client = Mock()
+ mock_client.post.return_value = mock_response
+ mock_client_class.return_value.__enter__.return_value = mock_client
+
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ # Set up existing refresh token
+ manager._refresh_token = "existing-refresh-token"
+
+ result = manager.refresh_access_token()
+
+ # Verify client was called correctly
+ mock_client.post.assert_called_once()
+ call_args = mock_client.post.call_args
+
+ # Verify form data
+ form_data = call_args[1]["data"]
+ assert form_data["grant_type"] == "refresh_token"
+ assert form_data["client_id"] == "test-client-id"
+ assert form_data["refresh_token"] == "existing-refresh-token"
+
+ # Verify result
+ assert result == "new-access-token"
+ assert manager._access_token == "new-access-token"
+
+ def test_refresh_access_token_no_refresh_token(self):
+ """Test token refresh when no refresh token is available."""
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ result = manager.refresh_access_token()
+ assert result is None
+
+ def test_get_valid_token_with_valid_token(self):
+ """Test getting valid token when token is still valid."""
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ # Set up valid token
+ manager._access_token = "valid-token"
+ manager._token_expires_at = time.time() + 3600 # Expires in 1 hour
+
+ result = manager.get_valid_token()
+ assert result == "valid-token"
+
+ @patch.object(OAuth2Manager, "refresh_access_token")
+ def test_get_valid_token_with_expired_token(self, mock_refresh):
+ """Test getting valid token when token is expired."""
+ mock_refresh.return_value = "refreshed-token"
+
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ # Set up expired token
+ manager._access_token = "expired-token"
+ manager._refresh_token = "refresh-token"
+ manager._token_expires_at = time.time() - 60 # Expired 1 minute ago
+
+ result = manager.get_valid_token()
+
+ assert result == "refreshed-token"
+ mock_refresh.assert_called_once()
+
+ def test_get_valid_token_no_tokens(self):
+ """Test getting valid token when no tokens are available."""
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ result = manager.get_valid_token()
+ assert result is None
+
+ def test_clear_tokens(self):
+ """Test clearing stored tokens."""
+ manager = OAuth2Manager(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-client-id",
+ redirect_uri="https://example.com/callback",
+ )
+
+ # Set up tokens
+ manager._access_token = "test-token"
+ manager._refresh_token = "test-refresh"
+ manager._token_expires_at = time.time() + 3600
+
+ manager.clear_tokens()
+
+ assert manager._access_token is None
+ assert manager._refresh_token is None
+ assert manager._token_expires_at is None
+
+
+class TestExtractAuthorizationCodeFromUrl:
+ """Test extract_authorization_code_from_url function."""
+
+ def test_extract_authorization_code_success(self):
+ """Test successful authorization code extraction."""
+ url = "https://example.com/callback?code=test-auth-code&state=test-state"
+
+ result = extract_authorization_code_from_url(url)
+
+ assert result == "test-auth-code"
+
+ def test_extract_authorization_code_with_error(self):
+ """Test authorization code extraction with OAuth2 error."""
+ url = "https://example.com/callback?error=access_denied&error_description=User+cancelled+authorization"
+
+ result = extract_authorization_code_from_url(url)
+
+ assert result is None
+
+ def test_extract_authorization_code_no_code(self):
+ """Test authorization code extraction when no code is present."""
+ url = "https://example.com/callback?state=test-state"
+
+ result = extract_authorization_code_from_url(url)
+
+ assert result is None
+
+ def test_extract_authorization_code_invalid_url(self):
+ """Test authorization code extraction with invalid URL."""
+ url = "not-a-valid-url"
+
+ result = extract_authorization_code_from_url(url)
+
+ assert result is None
+
+
+class TestGetOAuth2Headers:
+ """Test get_oauth2_headers function."""
+
+ def test_get_oauth2_headers(self):
+ """Test OAuth2 headers generation."""
+ token = "test-access-token"
+
+ headers = get_oauth2_headers(token)
+
+ expected_headers = {
+ "Authorization": "Bearer test-access-token",
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ }
+
+ assert headers == expected_headers
diff --git a/tests/test_submit_dag_by_id.py b/tests/test_submit_dag_by_id.py
new file mode 100644
index 0000000..12146a7
--- /dev/null
+++ b/tests/test_submit_dag_by_id.py
@@ -0,0 +1,229 @@
+from unittest.mock import Mock, patch
+
+from unity_initiator.actions.submit_dag_by_id import SubmitDagByID
+
+
+class TestSubmitDagByID:
+ """Test SubmitDagByID action."""
+
+ def test_init(self):
+ """Test initialization."""
+ payload = {"test": "data"}
+ payload_info = {"info": "test"}
+ params = {"dag_id": "test_dag"}
+
+ action = SubmitDagByID(payload, payload_info, params)
+
+ assert action._payload == payload
+ assert action._payload_info == payload_info
+ assert action._params == params
+
+ @patch("unity_initiator.actions.submit_dag_by_id.TokenManager")
+ def test_get_auth_token_with_cognito_credentials(self, mock_token_manager_class):
+ """Test token fetching with Cognito credentials."""
+ # Mock TokenManager instance
+ mock_manager = Mock()
+ mock_manager.get_valid_token.return_value = "test-token-123"
+ mock_token_manager_class.return_value = mock_manager
+
+ params = {
+ "unity_username": "testuser",
+ "unity_password": "testpass",
+ "unity_client_id": "test-client-id",
+ }
+
+ action = SubmitDagByID({}, {}, params)
+ token = action._get_auth_token()
+
+ assert token == "test-token-123"
+ mock_token_manager_class.assert_called_once_with(
+ username="testuser",
+ password="testpass",
+ client_id="test-client-id",
+ region="us-west-2",
+ )
+ mock_manager.get_valid_token.assert_called_once()
+
+ def test_get_auth_token_with_direct_token(self):
+ """Test token fetching with direct token."""
+ params = {"airflow_token": "direct-token-123"}
+
+ action = SubmitDagByID({}, {}, params)
+ token = action._get_auth_token()
+
+ assert token == "direct-token-123"
+
+ def test_get_auth_token_no_credentials(self):
+ """Test token fetching with no credentials."""
+ params = {}
+
+ action = SubmitDagByID({}, {}, params)
+ token = action._get_auth_token()
+
+ assert token is None
+
+ @patch("unity_initiator.actions.submit_dag_by_id.httpx.post")
+ @patch("unity_initiator.actions.submit_dag_by_id.get_auth_headers")
+ def test_execute_with_bearer_token(self, mock_get_headers, mock_post):
+ """Test execution with Bearer token authentication."""
+ # Mock successful response
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"dag_run_id": "test-run"}
+ mock_post.return_value = mock_response
+
+ mock_get_headers.return_value = {"Authorization": "Bearer test-token"}
+
+ params = {
+ "airflow_base_api_endpoint": "https://airflow.example.com",
+ "dag_id": "test_dag",
+ "airflow_token": "test-token",
+ }
+
+ action = SubmitDagByID({"test": "data"}, {"info": "test"}, params)
+ result = action.execute()
+
+ assert result["success"] is True
+ assert result["response"]["dag_run_id"] == "test-run"
+ mock_get_headers.assert_called_once_with(auth_type="bearer", token="test-token")
+
+ @patch("unity_initiator.actions.submit_dag_by_id.httpx.post")
+ @patch("unity_initiator.actions.submit_dag_by_id.get_auth_headers")
+ def test_execute_with_basic_auth(self, mock_get_headers, mock_post):
+ """Test execution with basic authentication."""
+ # Mock successful response
+ mock_response = Mock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = {"dag_run_id": "test-run"}
+ mock_post.return_value = mock_response
+
+ mock_get_headers.return_value = {"Authorization": "Basic dGVzdDp0ZXN0"}
+
+ params = {
+ "airflow_base_api_endpoint": "https://airflow.example.com",
+ "dag_id": "test_dag",
+ "airflow_username": "test",
+ "airflow_password": "test",
+ }
+
+ action = SubmitDagByID({"test": "data"}, {"info": "test"}, params)
+ result = action.execute()
+
+ assert result["success"] is True
+ assert result["response"]["dag_run_id"] == "test-run"
+ mock_get_headers.assert_called_once_with(
+ auth_type="basic", username="test", password="test"
+ )
+
+ @patch("unity_initiator.actions.submit_dag_by_id.httpx.post")
+ def test_execute_with_failed_response(self, mock_post):
+ """Test execution with failed response."""
+ # Mock failed response
+ mock_response = Mock()
+ mock_response.status_code = 400
+ mock_response.text = "Bad Request"
+ mock_post.return_value = mock_response
+
+ params = {
+ "airflow_base_api_endpoint": "https://airflow.example.com",
+ "dag_id": "test_dag",
+ "airflow_username": "test",
+ "airflow_password": "test",
+ }
+
+ action = SubmitDagByID({"test": "data"}, {"info": "test"}, params)
+ result = action.execute()
+
+ assert result["success"] is False
+ assert result["response"] == "Bad Request"
+
+ @patch("unity_initiator.actions.submit_dag_by_id.OAuth2Manager")
+ def test_get_auth_token_with_oauth2_credentials(self, mock_oauth2_manager_class):
+ """Test token fetching with OAuth2 credentials."""
+ # Mock OAuth2Manager instance
+ mock_manager = Mock()
+ mock_manager.get_valid_token.return_value = "oauth2-token-123"
+ mock_oauth2_manager_class.return_value = mock_manager
+
+ params = {
+ "oauth2_cognito_domain": "test.auth.us-west-2.amazoncognito.com",
+ "oauth2_client_id": "test-oauth2-client-id",
+ "oauth2_redirect_uri": "https://example.com/callback",
+ "oauth2_scope": "openid email profile",
+ "oauth2_region": "us-west-2",
+ }
+
+ action = SubmitDagByID({}, {}, params)
+ token = action._get_auth_token()
+
+ assert token == "oauth2-token-123"
+ mock_oauth2_manager_class.assert_called_once_with(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-oauth2-client-id",
+ redirect_uri="https://example.com/callback",
+ scope="openid email profile",
+ region="us-west-2",
+ verify_ssl=True,
+ )
+ mock_manager.get_valid_token.assert_called_once()
+
+ @patch("unity_initiator.actions.submit_dag_by_id.OAuth2Manager")
+ def test_get_auth_token_oauth2_priority_over_cognito(
+ self, mock_oauth2_manager_class
+ ):
+ """Test that OAuth2 credentials take priority over Cognito credentials."""
+ # Mock OAuth2Manager instance
+ mock_manager = Mock()
+ mock_manager.get_valid_token.return_value = "oauth2-token-123"
+ mock_oauth2_manager_class.return_value = mock_manager
+
+ params = {
+ # OAuth2 credentials
+ "oauth2_cognito_domain": "test.auth.us-west-2.amazoncognito.com",
+ "oauth2_client_id": "test-oauth2-client-id",
+ "oauth2_redirect_uri": "https://example.com/callback",
+ # Cognito credentials (should be ignored)
+ "unity_username": "testuser",
+ "unity_password": "testpass",
+ "unity_client_id": "test-client-id",
+ }
+
+ action = SubmitDagByID({}, {}, params)
+ token = action._get_auth_token()
+
+ assert token == "oauth2-token-123"
+ # OAuth2Manager should be called, not TokenManager
+ mock_oauth2_manager_class.assert_called_once()
+
+ @patch("unity_initiator.actions.submit_dag_by_id.OAuth2Manager")
+ def test_get_auth_token_with_oauth2_credentials_verify_ssl_false(
+ self, mock_oauth2_manager_class
+ ):
+ """Test token fetching with OAuth2 credentials and verify_ssl=False."""
+ # Mock OAuth2Manager instance
+ mock_manager = Mock()
+ mock_manager.get_valid_token.return_value = "oauth2-token-123"
+ mock_oauth2_manager_class.return_value = mock_manager
+
+ params = {
+ "oauth2_cognito_domain": "test.auth.us-west-2.amazoncognito.com",
+ "oauth2_client_id": "test-oauth2-client-id",
+ "oauth2_redirect_uri": "https://example.com/callback",
+ "oauth2_scope": "openid email profile",
+ "oauth2_region": "us-west-2",
+ "oauth2_verify_ssl": False,
+ }
+
+ action = SubmitDagByID({}, {}, params)
+ token = action._get_auth_token()
+
+ assert token == "oauth2-token-123"
+ mock_oauth2_manager_class.assert_called_once_with(
+ cognito_domain="test.auth.us-west-2.amazoncognito.com",
+ client_id="test-oauth2-client-id",
+ redirect_uri="https://example.com/callback",
+ scope="openid email profile",
+ region="us-west-2",
+ verify_ssl=False,
+ )
+ mock_manager.get_valid_token.assert_called_once()