diff --git a/mpesakit/http_client/http_client.py b/mpesakit/http_client/http_client.py index a10c226..4e50f5d 100644 --- a/mpesakit/http_client/http_client.py +++ b/mpesakit/http_client/http_client.py @@ -1,11 +1,19 @@ """http_client.py: Defines an abstract base HTTP client class for making HTTP requests. +Error handling for both synchronous ans Asynchronous clients + Provides a reusable interface for GET and POST requests. """ from typing import Dict, Any, Optional from abc import ABC, abstractmethod +from mpesakit.errors import MpesaApiException, MpesaError +from tenacity import ( + RetryCallState, + retry_if_exception_type, +) +import httpx class HttpClient(ABC): """Abstract base HTTP client for making GET and POST requests.""" @@ -51,3 +59,76 @@ async def get( ) -> Dict[str, Any]: """Sends an asynchronous GET request.""" pass + +def handle_request_error(response: httpx.Response): + """Handles non-successful HTTP responses. + + This function is now responsible for converting HTTP status codes + and JSON parsing errors into MpesaApiException. + """ + if response.is_success: + return + try: + response_data = response.json() + except ValueError: + response_data = {"errorMessage": response.text.strip() or ""} + + error_message = response_data.get("errorMessage", "") + raise MpesaApiException( + MpesaError( + error_code=f"HTTP_{response.status_code}", + error_message=error_message, + status_code=response.status_code, + raw_response=response_data, + ) + ) + + +def handle_retry_exception(retry_state: RetryCallState): + """Custom hook to handle exceptions after all retries fail. + + It raises a custom MpesaApiException with the appropriate error code. + """ + if retry_state.outcome: + exception = retry_state.outcome.exception() + + if isinstance(exception, httpx.TimeoutException): + raise MpesaApiException( + MpesaError(error_code="REQUEST_TIMEOUT", error_message=str(exception)) + ) from exception + elif isinstance(exception, httpx.ConnectError): + raise MpesaApiException( + MpesaError(error_code="CONNECTION_ERROR", error_message=str(exception)) + ) from exception + + raise MpesaApiException( + MpesaError(error_code="REQUEST_FAILED", error_message=str(exception)) + ) from exception + + raise MpesaApiException( + MpesaError( + error_code="REQUEST_FAILED", + error_message="An unknown retry error occurred.", + ) + ) + + +def retry_enabled(enabled: bool): + """Factory function to conditionally enable retries. + + Args: + enabled (bool): Whether to enable retry logic. + + Returns: + A retry condition function. + """ + base_retry = retry_if_exception_type( + httpx.TimeoutException + ) | retry_if_exception_type(httpx.ConnectError) + + def _retry(retry_state): + if not enabled: + return False + return base_retry(retry_state) + + return _retry diff --git a/mpesakit/http_client/mpesa_async_http_client.py b/mpesakit/http_client/mpesa_async_http_client.py index 55a5d41..cbaf068 100644 --- a/mpesakit/http_client/mpesa_async_http_client.py +++ b/mpesakit/http_client/mpesa_async_http_client.py @@ -1,12 +1,27 @@ """MpesaAsyncHttpClient: An asynchronous client for making HTTP requests to the M-Pesa API.""" -from typing import Dict, Any, Optional +import logging +import uuid +from typing import Dict, Any, Optional , MutableMapping import httpx from mpesakit.errors import MpesaError, MpesaApiException from .http_client import AsyncHttpClient +from tenacity import( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_random_exponential, + before_sleep_log, +) +from .http_client import retry_enabled, handle_request_error, handle_retry_exception +logger = logging.getLogger(__name__) + +mpesa_retry_condition = retry_if_exception_type( + (httpx.ConnectError, httpx.ConnectTimeout,httpx.TimeoutException ,httpx.ReadTimeout) +) class MpesaAsyncHttpClient(AsyncHttpClient): """An asynchronous client for making HTTP requests to the M-Pesa API. @@ -38,121 +53,171 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() + async def _raw_post( + self, + url: str, + json: Dict[str, Any], + headers: MutableMapping[str, str], + timeout: int = 10, + ) -> httpx.Response: + """Low-level asynchronous POST request - may raise httpx exceptions.""" + return await self._client.post( + url, + json=json, + headers=headers, + timeout=timeout, + ) + + @retry( + retry=mpesa_retry_condition, + wait=wait_random_exponential(multiplier=5, max=8), + stop=stop_after_attempt(3), + retry_error_callback=handle_retry_exception, + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + async def _retryable_post( + self, + url: str, + json: Dict[str, Any], + headers: MutableMapping[str, str], + timeout: int = 10, +) -> httpx.Response: + return await self._raw_post( + url, + json, + headers, + timeout, + ) + async def post( - self, url: str, json: Dict[str, Any], headers: Dict[str, str] + self, + url: str, + json: Dict[str, Any], + headers: Optional[Dict[str, str]] = None, + timeout: int = 10, + idempotent:bool = False, ) -> Dict[str, Any]: - """Sends an asynchronous POST request to the M-Pesa API.""" - try: + """Sends a asynchronous POST request to the M-Pesa API. - response = await self._client.post( - url, json=json, headers=headers, timeout=10 - ) - - - try: - response_data = response.json() - except ValueError: - response_data = {"errorMessage": response.text.strip() or ""} - - if not response.is_success: - error_message = response_data.get("errorMessage", "") - raise MpesaApiException( - MpesaError( - error_code=f"HTTP_{response.status_code}", - error_message=error_message, - status_code=response.status_code, - raw_response=response_data, - ) - ) + Args: + url (str): The URL path for the request. + json (Dict[str, Any]): The JSON payload for the request body. + headers (Dict[str, str]): The HTTP headers for the request. + timeout (int): The timeout for the request in seconds. + idempotent (bool): Add an idempotency key for safe retries. - return response_data + Returns: + Dict[str, Any]: The JSON response from the API. + """ + h= httpx.Headers( headers or {}) + + if idempotent and "X-Idempotency-Key" not in h: + h["X-Idempotency-Key"] = str(uuid.uuid4()) + + response: httpx.Response | None = None + + try: + if idempotent: + response = await self._retryable_post(url, json, h, timeout) + else: + response = await self._raw_post(url, json, h, timeout) + handle_request_error(response) + return response.json() + except httpx.ConnectTimeout as e: + raise MpesaApiException( + MpesaError( + error_code="REQUEST_TIMEOUT", + error_message=str(e) + ) + ) from e - except httpx.TimeoutException: + except httpx.TimeoutException as e: raise MpesaApiException( MpesaError( error_code="REQUEST_TIMEOUT", - error_message="Request to Mpesa timed out.", - status_code=None, + error_message=str(e) ) - ) - except httpx.ConnectError: + ) from e + except httpx.ConnectError as e: raise MpesaApiException( MpesaError( error_code="CONNECTION_ERROR", - error_message="Failed to connect to Mpesa API. Check network or URL.", - status_code=None, + error_message=str(e) ) - ) - except httpx.HTTPError as e: + ) from e + except (httpx.RequestError, ValueError) as e: raise MpesaApiException( MpesaError( error_code="REQUEST_FAILED", - error_message=f"HTTP request failed: {str(e)}", - status_code=None, - raw_response=None, + error_message=str(e), + status_code=getattr(response, "status_code", None), + raw_response=getattr(response, "text", None), ) - ) + ) from e + + + @retry( + retry=retry_enabled(enabled=True), + wait=wait_random_exponential(multiplier=5, max=8), + stop=stop_after_attempt(3), + retry_error_callback=handle_retry_exception, + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + async def _raw_get( + self, + url: str, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = 10, + ) -> httpx.Response: + """Low-level GET request - may raise httpx exceptions.""" + if headers is None: + headers = {} + + return await self._client.get( + url, + params=params, + headers=headers, + timeout=timeout, + ) async def get( self, url: str, params: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, + timeout: int = 10, ) -> Dict[str, Any]: - """Sends an asynchronous GET request to the M-Pesa API.""" - try: - if headers is None: - headers = {} - - response = await self._client.get( - url, params=params, headers=headers, timeout=10 - ) - - try: - response_data = response.json() - except ValueError: - response_data = {"errorMessage": response.text.strip() or ""} - - if not response.is_success: - error_message = response_data.get("errorMessage", "") - raise MpesaApiException( - MpesaError( - error_code=f"HTTP_{response.status_code}", - error_message=error_message, - status_code=response.status_code, - raw_response=response_data, - ) - ) + """Sends a GET request to the M-Pesa API. - return response_data + Args: + url (str): The URL path for the request. + params (Optional[Dict[str, Any]]): The URL parameters. + headers (Optional[Dict[str, str]]): The HTTP headers. + timeout (int): The timeout for the request in seconds. + Returns: + Dict[str, Any]: The JSON response from the API. + """ + response: httpx.Response | None = None - except httpx.TimeoutException: - raise MpesaApiException( - MpesaError( - error_code="REQUEST_TIMEOUT", - error_message="Request to Mpesa timed out.", - status_code=None, - ) - ) - except httpx.ConnectError: - raise MpesaApiException( - MpesaError( - error_code="CONNECTION_ERROR", - error_message="Failed to connect to Mpesa API. Check network or URL.", - status_code=None, - ) - ) - except httpx.HTTPError as e: + try: + response = await self._raw_get(url, params, headers, timeout) + handle_request_error(response) + return response.json() + + except (httpx.RequestError, ValueError) as e: raise MpesaApiException( MpesaError( error_code="REQUEST_FAILED", - error_message=f"HTTP request failed: {str(e)}", - status_code=None, - raw_response=None, + error_message=str(e), + status_code=getattr(response, "status_code", None), + raw_response=getattr(response, "text", None), ) - ) + ) from e async def aclose(self): """Manually close the underlying httpx client connection pool.""" diff --git a/mpesakit/http_client/mpesa_http_client.py b/mpesakit/http_client/mpesa_http_client.py index fd0a503..79e2c54 100644 --- a/mpesakit/http_client/mpesa_http_client.py +++ b/mpesakit/http_client/mpesa_http_client.py @@ -9,93 +9,19 @@ import httpx from tenacity import ( - RetryCallState, before_sleep_log, retry, - retry_if_exception_type, stop_after_attempt, wait_random_exponential, ) from mpesakit.errors import MpesaApiException, MpesaError -from .http_client import HttpClient +from .http_client import HttpClient,handle_request_error,handle_retry_exception,retry_enabled logger = logging.getLogger(__name__) -def handle_request_error(response: httpx.Response): - """Handles non-successful HTTP responses. - - This function is now responsible for converting HTTP status codes - and JSON parsing errors into MpesaApiException. - """ - if response.is_success: - return - try: - response_data = response.json() - except ValueError: - response_data = {"errorMessage": response.text.strip() or ""} - - error_message = response_data.get("errorMessage", "") - raise MpesaApiException( - MpesaError( - error_code=f"HTTP_{response.status_code}", - error_message=error_message, - status_code=response.status_code, - raw_response=response_data, - ) - ) - - -def handle_retry_exception(retry_state: RetryCallState): - """Custom hook to handle exceptions after all retries fail. - - It raises a custom MpesaApiException with the appropriate error code. - """ - if retry_state.outcome: - exception = retry_state.outcome.exception() - - if isinstance(exception, httpx.TimeoutException): - raise MpesaApiException( - MpesaError(error_code="REQUEST_TIMEOUT", error_message=str(exception)) - ) from exception - elif isinstance(exception, httpx.ConnectError): - raise MpesaApiException( - MpesaError(error_code="CONNECTION_ERROR", error_message=str(exception)) - ) from exception - - raise MpesaApiException( - MpesaError(error_code="REQUEST_FAILED", error_message=str(exception)) - ) from exception - - raise MpesaApiException( - MpesaError( - error_code="REQUEST_FAILED", - error_message="An unknown retry error occurred.", - ) - ) - - -def retry_enabled(enabled: bool): - """Factory function to conditionally enable retries. - - Args: - enabled (bool): Whether to enable retry logic. - - Returns: - A retry condition function. - """ - base_retry = retry_if_exception_type( - httpx.TimeoutException - ) | retry_if_exception_type(httpx.ConnectError) - - def _retry(retry_state): - if not enabled: - return False - return base_retry(retry_state) - - return _retry class MpesaHttpClient(HttpClient): diff --git a/mpesakit/mpesa_express/schemas.py b/mpesakit/mpesa_express/schemas.py index 96f370b..78876be 100644 --- a/mpesakit/mpesa_express/schemas.py +++ b/mpesakit/mpesa_express/schemas.py @@ -89,7 +89,7 @@ class StkPushSimulateRequest(BaseModel): json_schema_extra={ "example": { "BusinessShortCode": 654321, - "Password": "bXlwYXNzd29yZA==", + "Password": "", "Timestamp": "20240607123045", "TransactionType": "CustomerPayBillOnline", "Amount": 10, @@ -499,7 +499,7 @@ class StkPushQueryRequest(BaseModel): json_schema_extra={ "example": { "BusinessShortCode": 654321, - "Password": "bXlwYXNzd29yZA==", + "Password": "", "Timestamp": "20240607123045", "CheckoutRequestID": "ws_CO_DMZ_123212312_2342347678234", } diff --git a/tests/unit/http_client/test_mpesa_async_http_client.py b/tests/unit/http_client/test_mpesa_async_http_client.py index 4170f80..87c7e61 100644 --- a/tests/unit/http_client/test_mpesa_async_http_client.py +++ b/tests/unit/http_client/test_mpesa_async_http_client.py @@ -48,7 +48,10 @@ async def test_post_success(async_client): assert result == {"foo": "bar"} mock_post.assert_called_once() - mock_post.assert_called_with("/test", json={"a": 1}, headers={"h": "v"}, timeout=10) + mock_post.assert_called_with("/test", + json={"a": 1}, + headers={"h": "v"}, + timeout=10) @pytest.mark.asyncio @@ -128,7 +131,58 @@ async def test_post_generic_httpx_error(async_client): assert exc.value.error.error_code == "REQUEST_FAILED" assert "protocol error" in exc.value.error.error_message +@pytest.mark.asyncio +async def test_post_idempotency_key_persistence(async_client): + """Test that the same idempotency key is reused across retries.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.is_success = True + mock_response.json.return_value = {"Result": "Success"} + + with patch.object( + async_client._client, + "post", + side_effect=[httpx.ConnectError("Fail 1"), httpx.ConnectError("Fail 2"),mock_response], + new_callable=AsyncMock, + return_value=mock_response + ) as mock_post: + + await async_client.post("/stkpush",json={"Amount":1},headers={},idempotent=True) + assert mock_post.call_count == 3 + + first_call_headers = mock_post.call_args_list[0].kwargs['headers'] + third_call_headers = mock_post.call_args_list[2].kwargs['headers'] + + assert "X-Idempotency-Key" in first_call_headers + assert first_call_headers["X-Idempotency-Key"] == third_call_headers["X-Idempotency-Key"] + +@pytest.mark.asyncio +async def test_post_does_not_retry_on_400(async_client): + """Test that 400 errors fail immediately without retrying.""" + mock_response = Mock() + mock_response.status_code = 400 + mock_response.is_success = False + mock_response.json.return_value={"errorMessage":"Invalid MSISDN"} + + with patch.object(async_client._client,"post",new_callable=AsyncMock,return_value=mock_response) as mock_post: + with pytest.raises(MpesaApiException): + await async_client.post("/stkpush",json={"bad":"data"}) + + assert mock_post.call_count == 1 + +@pytest.mark.asyncio +async def test_post_retries_until_stop_limit(async_client): + """Verify that the client attempts 3 times for connection errors.""" + with patch.object( + async_client._client, + "post", + side_effect=httpx.ConnectError("Dead pipe") + )as mock_post: + with pytest.raises(MpesaApiException): + await async_client.post("/any",json={},idempotent=True) + + assert mock_post.call_count == 3 @pytest.mark.asyncio async def test_get_success(async_client): """Test successful ASYNC GET request returns expected JSON.""" @@ -171,7 +225,7 @@ async def test_get_timeout(async_client): await async_client.get("/timeout") assert exc.value.error.error_code == "REQUEST_TIMEOUT" - assert "timed out" in exc.value.error.error_message + assert "Test Timeout" in exc.value.error.error_message @pytest.mark.asyncio @@ -187,7 +241,7 @@ async def test_get_connection_error(async_client): await async_client.get("/conn") assert exc.value.error.error_code == "CONNECTION_ERROR" - assert "Failed to connect" in exc.value.error.error_message + assert "conn error" in exc.value.error.error_message @pytest.mark.asyncio @@ -203,4 +257,4 @@ async def test_get_generic_httpx_error(async_client): await async_client.get("/error") assert exc.value.error.error_code == "REQUEST_FAILED" - assert "HTTP request failed" in exc.value.error.error_message + assert "protocol error" in exc.value.error.error_message diff --git a/tests/unit/mpesa_express/test_stk_push.py b/tests/unit/mpesa_express/test_stk_push.py index 26248f8..99aa16a 100644 --- a/tests/unit/mpesa_express/test_stk_push.py +++ b/tests/unit/mpesa_express/test_stk_push.py @@ -3,12 +3,13 @@ This module tests the StkPush class for initiating and querying M-Pesa STK Push transactions. """ -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock,patch +import httpx import pytest from mpesakit.auth import AsyncTokenManager, TokenManager -from mpesakit.http_client import AsyncHttpClient, HttpClient +from mpesakit.http_client import AsyncHttpClient, HttpClient, MpesaHttpClient, MpesaAsyncHttpClient from mpesakit.mpesa_express.stk_push import ( AsyncStkPush, StkPush, @@ -160,6 +161,52 @@ def test_stk_push_simulate_request_invalid_transaction_type(): StkPushSimulateRequest(**valid_kwargs) assert "TransactionType must be one of:" in str(excinfo.value) +@pytest.mark.parametrize("use_session", [True, False]) +def test_stk_push_multiple_times(use_session, mock_token_manager): + """Test that StkPush service layer works correctly over multiple iterations.""" + with MpesaHttpClient(env="sandbox", use_session=use_session) as client: + stk = StkPush(http_client=client,token_manager=mock_token_manager) + + request_data = StkPushSimulateRequest( + BusinessShortCode=174379, + Amount=1, + PhoneNumber="254700000000", + CallBackURL="https://example.com/callback", + AccountReference="TestRef", + TransactionDesc="TestDesc", + TransactionType="CustomerPayBillOnline", + PartyA="254700000000", + PartyB="174379", + Timestamp="20231010120000", + Password="base64_encoded_password" + ) + + mock_response = MagicMock(spec=httpx.Response) + mock_response.is_success = True + mock_response.status_code = 200 + mock_response.json.return_value = { + "MerchantRequestID": "12345", + "ResponseCode": "0", + "CustomerMessage": "Success", + "CheckoutRequestID":"ws_CO_260520211133524545", + "ResponseDescription":"Test Description" + } + + + with patch.object(mock_token_manager, "get_token", return_value="mock_token"): + with patch.object(httpx.Client, "send", return_value=mock_response) as mock_send: + success_count = 0 + + for _ in range(100): + result = stk.push( + request=request_data + + ) + if str(result.ResponseCode) == "0": + success_count += 1 + + assert success_count == 100 + assert mock_send.call_count == 100 @pytest.fixture def mock_async_token_manager(): @@ -290,3 +337,50 @@ async def test_async_query_handles_http_error(async_stk_push, mock_async_http_cl with pytest.raises(Exception) as excinfo: await async_stk_push.query(request) assert "HTTP error" in str(excinfo.value) + +@pytest.mark.asyncio +async def test_async_stk_push_multiple_times(mock_async_token_manager): + """Test that StkPush can be used multiple times.""" + async with MpesaAsyncHttpClient() as client: + stk = AsyncStkPush(http_client =client,token_manager=mock_async_token_manager) + + request_data = StkPushSimulateRequest( + BusinessShortCode=174379, + Amount=1, + PhoneNumber="254700000000", + CallBackURL="https://example.com/callback", + AccountReference="TestRef", + TransactionDesc="TestDesc", + TransactionType="CustomerPayBillOnline", + PartyA="254700000000", + PartyB="174379", + Timestamp="20231010120000", + Password="base64_encoded_password" + ) + + + mock_response = MagicMock(spec=httpx.Response) + mock_response.is_success = True + mock_response.status_code = 200 + mock_response.json.return_value = { + "MerchantRequestID": "12345", + "CheckoutRequestID":"ws_CO_260520211133524545", + "ResponseDescription":"Test Description", + "CustomerMessage": "Success", + "ResponseCode": "0"} + + with patch.object(mock_async_token_manager, "get_token", AsyncMock(return_value="mock_token")): + with patch.object(httpx.AsyncClient, "send", AsyncMock(return_value=mock_response)) as mock_send: + success_count = 0 + + for _ in range(100): + result = await stk.push( + request=request_data + ) + if str(result.ResponseCode) == "0": + success_count += 1 + + assert success_count == 100 + assert mock_send.call_count == 100 + + await client.aclose()