diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 986fcc3..7a3c2c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ fail_fast: true repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: # Git style - id: check-merge-conflict @@ -9,14 +9,14 @@ repos: - id: trailing-whitespace - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 6.0.1 hooks: - id: isort args: ["--profile", "black", "--filter-files"] # Using this mirror lets us use mypyc-compiled black, which is about 2x faster - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.4.2 + rev: 25.1.0 hooks: - id: black # It is recommended to specify the latest version of Python @@ -27,24 +27,24 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.4.5 + rev: v0.12.3 hooks: - id: ruff args: ["--ignore", "E501,E402"] - repo: https://github.com/PyCQA/bandit - rev: "1.7.8" # you must change this to newest version + rev: "1.8.6" # you must change this to newest version hooks: - id: bandit args: ["--severity-level=high", "--confidence-level=high"] - repo: https://github.com/PyCQA/prospector - rev: v1.10.3 + rev: v1.17.2 hooks: - id: prospector - repo: https://github.com/antonbabenko/pre-commit-terraform - rev: v1.90.0 # Get the latest from: https://github.com/antonbabenko/pre-commit-terraform/releases + rev: v1.99.5 # Get the latest from: https://github.com/antonbabenko/pre-commit-terraform/releases hooks: # Terraform Tests - id: terraform_fmt diff --git a/.prospector.yaml b/.prospector.yaml index 8995ea9..f76baae 100644 --- a/.prospector.yaml +++ b/.prospector.yaml @@ -10,6 +10,11 @@ pylint: disable: - import-error - django-not-available + - import-outside-toplevel + - no-else-return + - consider-using-sys-exit + - too-many-arguments + - too-many-positional-arguments options: max-line-length: 159 diff --git a/CHANGELOG.md b/CHANGELOG.md index 7151834..7747ec5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,32 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [X.Y.Z] - 2022-MM-DD +## [0.0.2] - 2024-12-19 + +### Added + +- Cognito token authentication support with automatic refresh +- Bearer token authentication for enhanced security +- `TokenManager` class for token lifecycle management +- `fetch_cognito_token()` function for Cognito integration +- `get_auth_headers()` utility for authentication headers +- Comprehensive test coverage for authentication features +- Example scripts demonstrating token usage +- Updated documentation for authentication configuration + +### Changed + +- Enhanced `SubmitDagByID` action to support multiple authentication methods +- Added `httpx` dependency for modern HTTP client functionality +- Maintained backward compatibility with existing basic auth + +### Security + +- Replaced basic authentication with more secure Bearer token authentication +- Added automatic token refresh to prevent authentication failures +- Implemented token caching to reduce API calls to Cognito + +## [0.0.1] - 2022-MM-DD ### Added diff --git a/README.md b/README.md index fa41d3f..b685dc2 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,86 @@ and a trigger event payload for a new file that was triggered: In this case, the router sees that the action is `submit_dag_by_id` and thus makes a REST call to SPS to submit the URL payload, payload info, and `on_success` parameters as a DAG run. If the evaulator, running now as a DAG in SPS instead of an AWS Lambda function, successfully evaluates that everything is ready for this input file, it can proceed to submit a DAG run for the `submit_nisar_l0a_te_dag` DAG in the underlying SPS. +### Authentication for Airflow DAG Submissions + +The `submit_dag_by_id` action supports multiple authentication methods for connecting to Airflow REST APIs. The authentication method is determined by the parameters provided in the router configuration: + +#### 1. Bearer Token Authentication (Recommended) +Use a direct bearer token for authentication. This is the most secure method: + +```yaml +actions: + - name: submit_dag_by_id + params: + dag_id: example_dag + airflow_base_api_endpoint: https://airflow.example.com/api/v1 + airflow_token: ${AIRFLOW_BEARER_TOKEN} # Bearer token +``` + +#### 2. OAuth2 Authentication (For Proxy Servers) +Use OAuth2 authorization code flow for proxy authentication: + +```yaml +actions: + - name: submit_dag_by_id + params: + dag_id: example_dag + airflow_base_api_endpoint: https://proxy.example.com/api/v1 + oauth2_cognito_domain: your-domain.auth.us-west-2.amazoncognito.com + oauth2_client_id: your-oauth2-client-id + oauth2_redirect_uri: https://your-app.com/callback + oauth2_scope: openid email profile # Optional, defaults to "openid email profile" + oauth2_region: us-west-2 # Optional, defaults to us-west-2 + oauth2_verify_ssl: true # Optional, defaults to true for security +``` + +**OAuth2 Flow Setup**: +1. Use the provided `oauth2_token_init.py` script to initialize tokens +2. The script will guide you through the authorization flow +3. Tokens are automatically refreshed when needed + +#### 3. Cognito Token Authentication +Use Unity Cognito credentials to automatically fetch and refresh tokens: + +```yaml +actions: + - name: submit_dag_by_id + params: + dag_id: example_dag + airflow_base_api_endpoint: https://airflow.example.com/api/v1 + unity_username: ${UNITY_USERNAME} + unity_password: ${UNITY_PASSWORD} + unity_client_id: ${UNITY_CLIENT_ID} + unity_region: us-west-2 # Optional, defaults to us-west-2 +``` + +#### 4. Basic Authentication (Legacy) +Use username/password for basic authentication (less secure): + +```yaml +actions: + - name: submit_dag_by_id + params: + dag_id: example_dag + airflow_base_api_endpoint: https://airflow.example.com/api/v1 + airflow_username: ${AIRFLOW_USERNAME} + airflow_password: ${AIRFLOW_PASSWORD} +``` + +#### Authentication Priority +The system will use authentication in this order: +1. **Bearer token** (if `airflow_token` is provided) +2. **OAuth2 token** (if OAuth2 credentials are provided) +3. **Cognito token** (if Unity credentials are provided) +4. **Basic auth** (if username/password are provided) +5. **No authentication** (if no credentials are provided) + +#### Token Management +When using Cognito authentication: +- Tokens are automatically cached and refreshed 5 minutes before expiration +- Failed token refresh attempts fall back to credential-based fetching +- No manual token management required + + ## Requirements | Name | Version | @@ -37,4 +37,4 @@ No modules. | Name | Description | |------|-------------| | [centralized\_log\_group\_name](#output\_centralized\_log\_group\_name) | The name of the centralized log group | - + diff --git a/terraform-unity/evaluators/sns-sqs-lambda/README.md b/terraform-unity/evaluators/sns-sqs-lambda/README.md index 9d1aad5..8b8f2f8 100644 --- a/terraform-unity/evaluators/sns-sqs-lambda/README.md +++ b/terraform-unity/evaluators/sns-sqs-lambda/README.md @@ -1,6 +1,6 @@ # sns_sqs_lambda - + ## Requirements | Name | Version | @@ -62,4 +62,4 @@ No modules. | Name | Description | |------|-------------| | [evaluator\_topic\_arn](#output\_evaluator\_topic\_arn) | The ARN of the evaluator SNS topic | - + diff --git a/terraform-unity/initiator/README.md b/terraform-unity/initiator/README.md index 64d43b7..4d79947 100644 --- a/terraform-unity/initiator/README.md +++ b/terraform-unity/initiator/README.md @@ -1,6 +1,6 @@ # terraform-unity - + ## Requirements | Name | Version | @@ -60,4 +60,4 @@ No modules. | Name | Description | |------|-------------| | [initiator\_topic\_arn](#output\_initiator\_topic\_arn) | The ARN of the initiator SNS topic | - + diff --git a/terraform-unity/triggers/cmr-query/README.md b/terraform-unity/triggers/cmr-query/README.md index 64511ff..2fd97bf 100644 --- a/terraform-unity/triggers/cmr-query/README.md +++ b/terraform-unity/triggers/cmr-query/README.md @@ -1,6 +1,6 @@ # scheduled_task - + ## Requirements | Name | Version | @@ -63,4 +63,4 @@ No modules. ## Outputs No outputs. - + diff --git a/terraform-unity/triggers/s3-bucket-notification/README.md b/terraform-unity/triggers/s3-bucket-notification/README.md index 14c7a4e..49f3961 100644 --- a/terraform-unity/triggers/s3-bucket-notification/README.md +++ b/terraform-unity/triggers/s3-bucket-notification/README.md @@ -1,6 +1,6 @@ # s3_bucket_notification - + ## Requirements | Name | Version | @@ -38,4 +38,4 @@ No modules. ## Outputs No outputs. - + diff --git a/terraform-unity/triggers/scheduled-task-instrumented/README.md b/terraform-unity/triggers/scheduled-task-instrumented/README.md index ecdd896..52f16e3 100644 --- a/terraform-unity/triggers/scheduled-task-instrumented/README.md +++ b/terraform-unity/triggers/scheduled-task-instrumented/README.md @@ -1,6 +1,6 @@ # scheduled_task - + ## Requirements | Name | Version | @@ -54,4 +54,4 @@ No modules. ## Outputs No outputs. - + diff --git a/terraform-unity/triggers/scheduled-task/README.md b/terraform-unity/triggers/scheduled-task/README.md index 163c4bd..31c9871 100644 --- a/terraform-unity/triggers/scheduled-task/README.md +++ b/terraform-unity/triggers/scheduled-task/README.md @@ -1,6 +1,6 @@ # scheduled_task - + ## Requirements | Name | Version | @@ -49,4 +49,4 @@ No modules. ## Outputs No outputs. - + diff --git a/tests/resources/test_router_with_auth.yaml b/tests/resources/test_router_with_auth.yaml new file mode 100644 index 0000000..862506d --- /dev/null +++ b/tests/resources/test_router_with_auth.yaml @@ -0,0 +1,55 @@ +initiator_config: + name: test config with authentication examples + + payload_type: + url: + # Test with Bearer token authentication + - regexes: + - '/(?Pbearer_test_(?P\d{10})\.json)$' + evaluators: + - name: eval_bearer_auth + actions: + - name: submit_dag_by_id + params: + dag_id: test_bearer_auth_dag + airflow_base_api_endpoint: https://airflow.example.com/api/v1 + airflow_token: test-bearer-token-123 + + # Test with Cognito authentication + - regexes: + - '/(?Pcognito_test_(?P\d{10})\.json)$' + evaluators: + - name: eval_cognito_auth + actions: + - name: submit_dag_by_id + params: + dag_id: test_cognito_auth_dag + airflow_base_api_endpoint: https://airflow.example.com/api/v1 + unity_username: testuser + unity_password: testpass + unity_client_id: test-client-id + unity_region: us-west-2 + + # Test with Basic authentication + - regexes: + - '/(?Pbasic_test_(?P\d{10})\.json)$' + evaluators: + - name: eval_basic_auth + actions: + - name: submit_dag_by_id + params: + dag_id: test_basic_auth_dag + airflow_base_api_endpoint: https://airflow.example.com/api/v1 + airflow_username: testuser + airflow_password: testpass + + # Test with no authentication + - regexes: + - '/(?Pno_auth_test_(?P\d{10})\.json)$' + evaluators: + - name: eval_no_auth + actions: + - name: submit_dag_by_id + params: + dag_id: test_no_auth_dag + airflow_base_api_endpoint: https://airflow.example.com/api/v1 \ No newline at end of file diff --git a/tests/test_auth_utils.py b/tests/test_auth_utils.py new file mode 100644 index 0000000..04e6e78 --- /dev/null +++ b/tests/test_auth_utils.py @@ -0,0 +1,272 @@ +import time +from unittest.mock import Mock, patch + +from unity_initiator.utils.auth_utils import ( + TokenInfo, + TokenManager, + fetch_cognito_token, + get_auth_headers, +) + + +class TestAuthUtils: + """Test authentication utilities.""" + + def test_get_auth_headers_basic(self): + """Test basic authentication headers.""" + headers = get_auth_headers( + auth_type="basic", username="testuser", password="testpass" + ) + + assert "Authorization" in headers + assert headers["Authorization"].startswith("Basic ") + assert headers["Content-Type"] == "application/json" + assert headers["Accept"] == "application/json" + + def test_get_auth_headers_bearer(self): + """Test bearer token authentication headers.""" + token = "test-token-123" + headers = get_auth_headers(auth_type="bearer", token=token) + + assert "Authorization" in headers + assert headers["Authorization"] == f"Bearer {token}" + assert headers["Content-Type"] == "application/json" + assert headers["Accept"] == "application/json" + + def test_get_auth_headers_no_auth(self): + """Test headers without authentication.""" + headers = get_auth_headers() + + assert "Authorization" not in headers + assert headers["Content-Type"] == "application/json" + assert headers["Accept"] == "application/json" + + @patch("unity_initiator.utils.auth_utils.TokenManager") + def test_fetch_cognito_token_success(self, mock_token_manager_class): + """Test successful Cognito token fetching.""" + # Mock TokenManager instance + mock_manager = Mock() + mock_manager.get_valid_token.return_value = "test-access-token-123" + mock_token_manager_class.return_value = mock_manager + + token = fetch_cognito_token( + username="testuser", password="testpass", client_id="test-client-id" + ) + + assert token == "test-access-token-123" + mock_token_manager_class.assert_called_once_with( + "testuser", + "testpass", + "test-client-id", + "us-west-2", + ) + mock_manager.get_valid_token.assert_called_once() + + @patch("unity_initiator.utils.auth_utils.TokenManager") + def test_fetch_cognito_token_no_auth_result(self, mock_token_manager_class): + """Test Cognito token fetching with no authentication result.""" + # Mock TokenManager instance that returns None + mock_manager = Mock() + mock_manager.get_valid_token.return_value = None + mock_token_manager_class.return_value = mock_manager + + token = fetch_cognito_token( + username="testuser", password="testpass", client_id="test-client-id" + ) + + assert token is None + + @patch("unity_initiator.utils.auth_utils.TokenManager") + def test_fetch_cognito_token_http_error(self, mock_token_manager_class): + """Test Cognito token fetching with HTTP error.""" + # Mock TokenManager instance that returns None due to error + mock_manager = Mock() + mock_manager.get_valid_token.return_value = None + mock_token_manager_class.return_value = mock_manager + + token = fetch_cognito_token( + username="testuser", password="testpass", client_id="test-client-id" + ) + + assert token is None + + @patch("unity_initiator.utils.auth_utils.TokenManager") + def test_fetch_cognito_token_request_error(self, mock_token_manager_class): + """Test Cognito token fetching with request error.""" + # Mock TokenManager instance that returns None due to error + mock_manager = Mock() + mock_manager.get_valid_token.return_value = None + mock_token_manager_class.return_value = mock_manager + + token = fetch_cognito_token( + username="testuser", password="testpass", client_id="test-client-id" + ) + + assert token is None + + +class TestTokenManager: + """Test TokenManager class.""" + + def test_token_manager_init(self): + """Test TokenManager initialization.""" + manager = TokenManager("user", "pass", "client-id", "us-west-2") + + assert manager.username == "user" + assert manager.password == "pass" + assert manager.client_id == "client-id" + assert manager.region == "us-west-2" + assert manager._token_cache is None + assert manager._refresh_buffer == 300 + + def test_token_info_dataclass(self): + """Test TokenInfo dataclass.""" + token_info = TokenInfo( + access_token="test-token", + expires_at=time.time() + 3600, + refresh_token="refresh-token", + ) + + assert token_info.access_token == "test-token" + assert token_info.expires_at > time.time() + assert token_info.refresh_token == "refresh-token" + + @patch("unity_initiator.utils.auth_utils.httpx.Client") + def test_get_valid_token_first_time(self, mock_client_class): + """Test getting token for the first time.""" + # Mock successful response + mock_response = Mock() + mock_response.json.return_value = { + "AuthenticationResult": { + "AccessToken": "test-token-123", + "ExpiresIn": 3600, + "RefreshToken": "refresh-token-123", + } + } + mock_response.raise_for_status.return_value = None + + mock_client = Mock() + mock_client.post.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client + mock_client_class.return_value.__exit__.return_value = None + + manager = TokenManager("user", "pass", "client-id") + token = manager.get_valid_token() + + assert token == "test-token-123" + assert manager._token_cache is not None + assert manager._token_cache.access_token == "test-token-123" + assert manager._token_cache.refresh_token == "refresh-token-123" + + def test_get_valid_token_cached_valid(self): + """Test getting token when cached token is still valid.""" + # Create a token that expires in 1 hour + expires_at = time.time() + 3600 + token_info = TokenInfo( + access_token="cached-token", + expires_at=expires_at, + refresh_token="refresh-token", + ) + + manager = TokenManager("user", "pass", "client-id") + manager._token_cache = token_info + + token = manager.get_valid_token() + + assert token == "cached-token" + + def test_get_valid_token_cached_expired(self): + """Test getting token when cached token is expired.""" + # Create a token that expired 1 hour ago + expires_at = time.time() - 3600 + token_info = TokenInfo( + access_token="expired-token", + expires_at=expires_at, + refresh_token="refresh-token", + ) + + manager = TokenManager("user", "pass", "client-id") + manager._token_cache = token_info + + # Mock the _fetch_new_token method + with patch.object(manager, "_fetch_new_token", return_value="new-token"): + token = manager.get_valid_token() + + assert token == "new-token" + + def test_get_valid_token_cached_expiring_soon(self): + """Test getting token when cached token is expiring soon.""" + # Create a token that expires in 2 minutes (less than 5-minute buffer) + expires_at = time.time() + 120 + token_info = TokenInfo( + access_token="expiring-token", + expires_at=expires_at, + refresh_token="refresh-token", + ) + + manager = TokenManager("user", "pass", "client-id") + manager._token_cache = token_info + + # Mock the _fetch_new_token method + with patch.object(manager, "_fetch_new_token", return_value="new-token"): + token = manager.get_valid_token() + + assert token == "new-token" + + @patch("unity_initiator.utils.auth_utils.httpx.Client") + def test_refresh_token_success(self, mock_client_class): + """Test successful token refresh.""" + # Mock successful refresh response + mock_response = Mock() + mock_response.json.return_value = { + "AuthenticationResult": { + "AccessToken": "refreshed-token", + "ExpiresIn": 3600, + } + } + mock_response.raise_for_status.return_value = None + + mock_client = Mock() + mock_client.post.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client + mock_client_class.return_value.__exit__.return_value = None + + manager = TokenManager("user", "pass", "client-id") + manager._token_cache = TokenInfo( + access_token="old-token", + expires_at=time.time() - 3600, # Expired + refresh_token="refresh-token", + ) + + token = manager._refresh_token() + + assert token == "refreshed-token" + assert manager._token_cache.access_token == "refreshed-token" + + def test_refresh_token_no_refresh_token(self): + """Test refresh when no refresh token is available.""" + manager = TokenManager("user", "pass", "client-id") + manager._token_cache = TokenInfo( + access_token="old-token", + expires_at=time.time() - 3600, # Expired + refresh_token=None, # No refresh token + ) + + # Mock the _fetch_new_token method + with patch.object(manager, "_fetch_new_token", return_value="new-token"): + token = manager._refresh_token() + + assert token == "new-token" + + def test_clear_cache(self): + """Test clearing the token cache.""" + manager = TokenManager("user", "pass", "client-id") + manager._token_cache = TokenInfo( + access_token="test-token", + expires_at=time.time() + 3600, + refresh_token="refresh-token", + ) + + manager.clear_cache() + + assert manager._token_cache is None diff --git a/tests/test_oauth2_utils.py b/tests/test_oauth2_utils.py new file mode 100644 index 0000000..fd8db39 --- /dev/null +++ b/tests/test_oauth2_utils.py @@ -0,0 +1,353 @@ +""" +Tests for OAuth2 authentication utilities. +""" + +import time +from unittest.mock import Mock, patch +from urllib.parse import parse_qs + +import httpx +import pytest + +from unity_initiator.utils.oauth2_utils import ( + OAuth2Manager, + extract_authorization_code_from_url, + get_oauth2_headers, +) + + +class TestOAuth2Manager: + """Test OAuth2Manager class.""" + + def test_init(self): + """Test OAuth2Manager initialization.""" + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + assert manager.cognito_domain == "test.auth.us-west-2.amazoncognito.com" + assert manager.client_id == "test-client-id" + assert manager.redirect_uri == "https://example.com/callback" + assert manager.scope == "openid email profile" + assert manager.region == "us-west-2" + assert manager.verify_ssl is True + assert ( + manager.auth_endpoint + == "https://test.auth.us-west-2.amazoncognito.com/oauth2/authorize" + ) + assert ( + manager.token_endpoint + == "https://test.auth.us-west-2.amazoncognito.com/oauth2/token" + ) + + def test_init_with_verify_ssl_false(self): + """Test initialization with verify_ssl=False.""" + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + verify_ssl=False, + ) + + assert manager.verify_ssl is False + + def test_get_authorization_url(self): + """Test authorization URL generation.""" + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + auth_url = manager.get_authorization_url() + + # Parse URL and check parameters + parsed = httpx.URL(auth_url) + params = parse_qs( + parsed.query.decode() + if hasattr(parsed.query, "decode") + else str(parsed.query) + ) + + assert parsed.scheme == "https" + assert parsed.host == "test.auth.us-west-2.amazoncognito.com" + assert parsed.path == "/oauth2/authorize" + assert params["response_type"] == ["code"] + assert params["client_id"] == ["test-client-id"] + assert params["redirect_uri"] == ["https://example.com/callback"] + assert params["scope"] == ["openid email profile"] + assert "state" in params + assert "nonce" in params + + def test_get_authorization_url_with_custom_state(self): + """Test authorization URL generation with custom state.""" + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + auth_url = manager.get_authorization_url(state="custom-state") + + parsed = httpx.URL(auth_url) + params = parse_qs( + parsed.query.decode() + if hasattr(parsed.query, "decode") + else str(parsed.query) + ) + + assert params["state"] == ["custom-state"] + + @patch("unity_initiator.utils.oauth2_utils.httpx.Client") + def test_exchange_code_for_token_success(self, mock_client_class): + """Test successful token exchange.""" + # Mock response + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "token_type": "Bearer", + } + mock_response.raise_for_status.return_value = None + + # Mock client + mock_client = Mock() + mock_client.post.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client + + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + result = manager.exchange_code_for_token("test-auth-code") + + # Verify client was called correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + assert ( + call_args[0][0] + == "https://test.auth.us-west-2.amazoncognito.com/oauth2/token" + ) + + # Verify form data + form_data = call_args[1]["data"] + assert form_data["grant_type"] == "authorization_code" + assert form_data["client_id"] == "test-client-id" + assert form_data["code"] == "test-auth-code" + assert form_data["redirect_uri"] == "https://example.com/callback" + + # Verify headers + headers = call_args[1]["headers"] + assert headers["Content-Type"] == "application/x-www-form-urlencoded" + + # Verify result + assert result["access_token"] == "test-access-token" + assert result["refresh_token"] == "test-refresh-token" + assert result["expires_in"] == 3600 + + # Verify tokens were stored + assert manager._access_token == "test-access-token" + assert manager._refresh_token == "test-refresh-token" + assert manager._token_expires_at is not None + + @patch("unity_initiator.utils.oauth2_utils.httpx.Client") + def test_exchange_code_for_token_http_error(self, mock_client_class): + """Test token exchange with HTTP error.""" + # Mock HTTP error response + mock_response = Mock() + mock_response.text = "Invalid authorization code" + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "400 Bad Request", request=Mock(), response=mock_response + ) + + # Mock client + mock_client = Mock() + mock_client.post.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client + + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + with pytest.raises(httpx.HTTPStatusError): + manager.exchange_code_for_token("invalid-code") + + @patch("unity_initiator.utils.oauth2_utils.httpx.Client") + def test_refresh_access_token_success(self, mock_client_class): + """Test successful token refresh.""" + # Mock response + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "new-access-token", + "expires_in": 3600, + "token_type": "Bearer", + } + mock_response.raise_for_status.return_value = None + + # Mock client + mock_client = Mock() + mock_client.post.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client + + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + # Set up existing refresh token + manager._refresh_token = "existing-refresh-token" + + result = manager.refresh_access_token() + + # Verify client was called correctly + mock_client.post.assert_called_once() + call_args = mock_client.post.call_args + + # Verify form data + form_data = call_args[1]["data"] + assert form_data["grant_type"] == "refresh_token" + assert form_data["client_id"] == "test-client-id" + assert form_data["refresh_token"] == "existing-refresh-token" + + # Verify result + assert result == "new-access-token" + assert manager._access_token == "new-access-token" + + def test_refresh_access_token_no_refresh_token(self): + """Test token refresh when no refresh token is available.""" + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + result = manager.refresh_access_token() + assert result is None + + def test_get_valid_token_with_valid_token(self): + """Test getting valid token when token is still valid.""" + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + # Set up valid token + manager._access_token = "valid-token" + manager._token_expires_at = time.time() + 3600 # Expires in 1 hour + + result = manager.get_valid_token() + assert result == "valid-token" + + @patch.object(OAuth2Manager, "refresh_access_token") + def test_get_valid_token_with_expired_token(self, mock_refresh): + """Test getting valid token when token is expired.""" + mock_refresh.return_value = "refreshed-token" + + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + # Set up expired token + manager._access_token = "expired-token" + manager._refresh_token = "refresh-token" + manager._token_expires_at = time.time() - 60 # Expired 1 minute ago + + result = manager.get_valid_token() + + assert result == "refreshed-token" + mock_refresh.assert_called_once() + + def test_get_valid_token_no_tokens(self): + """Test getting valid token when no tokens are available.""" + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + result = manager.get_valid_token() + assert result is None + + def test_clear_tokens(self): + """Test clearing stored tokens.""" + manager = OAuth2Manager( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-client-id", + redirect_uri="https://example.com/callback", + ) + + # Set up tokens + manager._access_token = "test-token" + manager._refresh_token = "test-refresh" + manager._token_expires_at = time.time() + 3600 + + manager.clear_tokens() + + assert manager._access_token is None + assert manager._refresh_token is None + assert manager._token_expires_at is None + + +class TestExtractAuthorizationCodeFromUrl: + """Test extract_authorization_code_from_url function.""" + + def test_extract_authorization_code_success(self): + """Test successful authorization code extraction.""" + url = "https://example.com/callback?code=test-auth-code&state=test-state" + + result = extract_authorization_code_from_url(url) + + assert result == "test-auth-code" + + def test_extract_authorization_code_with_error(self): + """Test authorization code extraction with OAuth2 error.""" + url = "https://example.com/callback?error=access_denied&error_description=User+cancelled+authorization" + + result = extract_authorization_code_from_url(url) + + assert result is None + + def test_extract_authorization_code_no_code(self): + """Test authorization code extraction when no code is present.""" + url = "https://example.com/callback?state=test-state" + + result = extract_authorization_code_from_url(url) + + assert result is None + + def test_extract_authorization_code_invalid_url(self): + """Test authorization code extraction with invalid URL.""" + url = "not-a-valid-url" + + result = extract_authorization_code_from_url(url) + + assert result is None + + +class TestGetOAuth2Headers: + """Test get_oauth2_headers function.""" + + def test_get_oauth2_headers(self): + """Test OAuth2 headers generation.""" + token = "test-access-token" + + headers = get_oauth2_headers(token) + + expected_headers = { + "Authorization": "Bearer test-access-token", + "Content-Type": "application/json", + "Accept": "application/json", + } + + assert headers == expected_headers diff --git a/tests/test_submit_dag_by_id.py b/tests/test_submit_dag_by_id.py new file mode 100644 index 0000000..12146a7 --- /dev/null +++ b/tests/test_submit_dag_by_id.py @@ -0,0 +1,229 @@ +from unittest.mock import Mock, patch + +from unity_initiator.actions.submit_dag_by_id import SubmitDagByID + + +class TestSubmitDagByID: + """Test SubmitDagByID action.""" + + def test_init(self): + """Test initialization.""" + payload = {"test": "data"} + payload_info = {"info": "test"} + params = {"dag_id": "test_dag"} + + action = SubmitDagByID(payload, payload_info, params) + + assert action._payload == payload + assert action._payload_info == payload_info + assert action._params == params + + @patch("unity_initiator.actions.submit_dag_by_id.TokenManager") + def test_get_auth_token_with_cognito_credentials(self, mock_token_manager_class): + """Test token fetching with Cognito credentials.""" + # Mock TokenManager instance + mock_manager = Mock() + mock_manager.get_valid_token.return_value = "test-token-123" + mock_token_manager_class.return_value = mock_manager + + params = { + "unity_username": "testuser", + "unity_password": "testpass", + "unity_client_id": "test-client-id", + } + + action = SubmitDagByID({}, {}, params) + token = action._get_auth_token() + + assert token == "test-token-123" + mock_token_manager_class.assert_called_once_with( + username="testuser", + password="testpass", + client_id="test-client-id", + region="us-west-2", + ) + mock_manager.get_valid_token.assert_called_once() + + def test_get_auth_token_with_direct_token(self): + """Test token fetching with direct token.""" + params = {"airflow_token": "direct-token-123"} + + action = SubmitDagByID({}, {}, params) + token = action._get_auth_token() + + assert token == "direct-token-123" + + def test_get_auth_token_no_credentials(self): + """Test token fetching with no credentials.""" + params = {} + + action = SubmitDagByID({}, {}, params) + token = action._get_auth_token() + + assert token is None + + @patch("unity_initiator.actions.submit_dag_by_id.httpx.post") + @patch("unity_initiator.actions.submit_dag_by_id.get_auth_headers") + def test_execute_with_bearer_token(self, mock_get_headers, mock_post): + """Test execution with Bearer token authentication.""" + # Mock successful response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"dag_run_id": "test-run"} + mock_post.return_value = mock_response + + mock_get_headers.return_value = {"Authorization": "Bearer test-token"} + + params = { + "airflow_base_api_endpoint": "https://airflow.example.com", + "dag_id": "test_dag", + "airflow_token": "test-token", + } + + action = SubmitDagByID({"test": "data"}, {"info": "test"}, params) + result = action.execute() + + assert result["success"] is True + assert result["response"]["dag_run_id"] == "test-run" + mock_get_headers.assert_called_once_with(auth_type="bearer", token="test-token") + + @patch("unity_initiator.actions.submit_dag_by_id.httpx.post") + @patch("unity_initiator.actions.submit_dag_by_id.get_auth_headers") + def test_execute_with_basic_auth(self, mock_get_headers, mock_post): + """Test execution with basic authentication.""" + # Mock successful response + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"dag_run_id": "test-run"} + mock_post.return_value = mock_response + + mock_get_headers.return_value = {"Authorization": "Basic dGVzdDp0ZXN0"} + + params = { + "airflow_base_api_endpoint": "https://airflow.example.com", + "dag_id": "test_dag", + "airflow_username": "test", + "airflow_password": "test", + } + + action = SubmitDagByID({"test": "data"}, {"info": "test"}, params) + result = action.execute() + + assert result["success"] is True + assert result["response"]["dag_run_id"] == "test-run" + mock_get_headers.assert_called_once_with( + auth_type="basic", username="test", password="test" + ) + + @patch("unity_initiator.actions.submit_dag_by_id.httpx.post") + def test_execute_with_failed_response(self, mock_post): + """Test execution with failed response.""" + # Mock failed response + mock_response = Mock() + mock_response.status_code = 400 + mock_response.text = "Bad Request" + mock_post.return_value = mock_response + + params = { + "airflow_base_api_endpoint": "https://airflow.example.com", + "dag_id": "test_dag", + "airflow_username": "test", + "airflow_password": "test", + } + + action = SubmitDagByID({"test": "data"}, {"info": "test"}, params) + result = action.execute() + + assert result["success"] is False + assert result["response"] == "Bad Request" + + @patch("unity_initiator.actions.submit_dag_by_id.OAuth2Manager") + def test_get_auth_token_with_oauth2_credentials(self, mock_oauth2_manager_class): + """Test token fetching with OAuth2 credentials.""" + # Mock OAuth2Manager instance + mock_manager = Mock() + mock_manager.get_valid_token.return_value = "oauth2-token-123" + mock_oauth2_manager_class.return_value = mock_manager + + params = { + "oauth2_cognito_domain": "test.auth.us-west-2.amazoncognito.com", + "oauth2_client_id": "test-oauth2-client-id", + "oauth2_redirect_uri": "https://example.com/callback", + "oauth2_scope": "openid email profile", + "oauth2_region": "us-west-2", + } + + action = SubmitDagByID({}, {}, params) + token = action._get_auth_token() + + assert token == "oauth2-token-123" + mock_oauth2_manager_class.assert_called_once_with( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-oauth2-client-id", + redirect_uri="https://example.com/callback", + scope="openid email profile", + region="us-west-2", + verify_ssl=True, + ) + mock_manager.get_valid_token.assert_called_once() + + @patch("unity_initiator.actions.submit_dag_by_id.OAuth2Manager") + def test_get_auth_token_oauth2_priority_over_cognito( + self, mock_oauth2_manager_class + ): + """Test that OAuth2 credentials take priority over Cognito credentials.""" + # Mock OAuth2Manager instance + mock_manager = Mock() + mock_manager.get_valid_token.return_value = "oauth2-token-123" + mock_oauth2_manager_class.return_value = mock_manager + + params = { + # OAuth2 credentials + "oauth2_cognito_domain": "test.auth.us-west-2.amazoncognito.com", + "oauth2_client_id": "test-oauth2-client-id", + "oauth2_redirect_uri": "https://example.com/callback", + # Cognito credentials (should be ignored) + "unity_username": "testuser", + "unity_password": "testpass", + "unity_client_id": "test-client-id", + } + + action = SubmitDagByID({}, {}, params) + token = action._get_auth_token() + + assert token == "oauth2-token-123" + # OAuth2Manager should be called, not TokenManager + mock_oauth2_manager_class.assert_called_once() + + @patch("unity_initiator.actions.submit_dag_by_id.OAuth2Manager") + def test_get_auth_token_with_oauth2_credentials_verify_ssl_false( + self, mock_oauth2_manager_class + ): + """Test token fetching with OAuth2 credentials and verify_ssl=False.""" + # Mock OAuth2Manager instance + mock_manager = Mock() + mock_manager.get_valid_token.return_value = "oauth2-token-123" + mock_oauth2_manager_class.return_value = mock_manager + + params = { + "oauth2_cognito_domain": "test.auth.us-west-2.amazoncognito.com", + "oauth2_client_id": "test-oauth2-client-id", + "oauth2_redirect_uri": "https://example.com/callback", + "oauth2_scope": "openid email profile", + "oauth2_region": "us-west-2", + "oauth2_verify_ssl": False, + } + + action = SubmitDagByID({}, {}, params) + token = action._get_auth_token() + + assert token == "oauth2-token-123" + mock_oauth2_manager_class.assert_called_once_with( + cognito_domain="test.auth.us-west-2.amazoncognito.com", + client_id="test-oauth2-client-id", + redirect_uri="https://example.com/callback", + scope="openid email profile", + region="us-west-2", + verify_ssl=False, + ) + mock_manager.get_valid_token.assert_called_once()