diff --git a/winrm/protocol.py b/winrm/protocol.py index a93d7fa..a2cfafe 100644 --- a/winrm/protocol.py +++ b/winrm/protocol.py @@ -32,10 +32,12 @@ class Protocol(object): first. """ - DEFAULT_READ_TIMEOUT_SEC = 30 - DEFAULT_OPERATION_TIMEOUT_SEC = 20 - DEFAULT_MAX_ENV_SIZE = 153600 DEFAULT_LOCALE = "en-US" + DEFAULT_MAX_ENV_SIZE = 153600 + DEFAULT_OPERATION_TIMEOUT_SEC = 20 + DEFAULT_READ_TIMEOUT_SEC = 30 + DEFAULT_RECONNECTION_BACKOFF = 2.0 + DEFAULT_RECONNECTION_RETRIES = 0 def __init__( self, @@ -53,6 +55,8 @@ def __init__( kerberos_delegation: bool = False, read_timeout_sec: str | int = DEFAULT_READ_TIMEOUT_SEC, operation_timeout_sec: str | int = DEFAULT_OPERATION_TIMEOUT_SEC, + reconnection_retries: str | int = DEFAULT_RECONNECTION_RETRIES, + reconnection_backoff: str | float = DEFAULT_RECONNECTION_BACKOFF, kerberos_hostname_override: str | None = None, message_encryption: t.Literal["auto", "always", "never"] = "auto", credssp_disable_tlsv1_2: bool = False, @@ -77,6 +81,8 @@ def __init__( @param bool kerberos_delegation: if True, TGT is sent to target server to allow multiple hops # NOQA @param int read_timeout_sec: maximum seconds to wait before an HTTP connect/read times out (default 30). This value should be slightly higher than operation_timeout_sec, as the server can block *at least* that long. # NOQA @param int operation_timeout_sec: maximum allowed time in seconds for any single wsman HTTP operation (default 20). Note that operation timeouts while receiving output (the only wsman operation that should take any significant time, and where these timeouts are expected) will be silently retried indefinitely. # NOQA + @param int reconnection_retries: Number of retries on connection problems + @param float reconnection_backoff: Number of seconds to backoff in between reconnection attempts (first sleeps X, then sleeps 2*X, then sleeps 4*X, ...) @param string kerberos_hostname_override: the hostname to use for the kerberos exchange (defaults to the hostname in the endpoint URL) @param bool message_encryption_enabled: Will encrypt the WinRM messages if set to True and the transport auth supports message encryption (Default True). @param string proxy: Specify a proxy for the WinRM connection to use. 'legacy_requests'(default) to use environment variables, None to disable proxies completely or the proxy URL itself. @@ -95,6 +101,16 @@ def __init__( if operation_timeout_sec >= read_timeout_sec or operation_timeout_sec < 1: raise WinRMError("read_timeout_sec must exceed operation_timeout_sec, and both must be non-zero") + try: + reconnection_retries = int(reconnection_retries) + except ValueError as ve: + raise ValueError("failed to parse reconnection_retries as int: %s" % str(ve)) + + try: + reconnection_backoff = float(reconnection_backoff) + except ValueError as ve: + raise ValueError("failed to parse reconnection_backoff as float: %s" % str(ve)) + self.read_timeout_sec = read_timeout_sec self.operation_timeout_sec = operation_timeout_sec self.max_env_sz = Protocol.DEFAULT_MAX_ENV_SIZE @@ -111,6 +127,8 @@ def __init__( cert_pem=cert_pem, cert_key_pem=cert_key_pem, read_timeout_sec=self.read_timeout_sec, + reconnection_retries=reconnection_retries, + reconnection_backoff=reconnection_backoff, server_cert_validation=server_cert_validation, kerberos_delegation=kerberos_delegation, kerberos_hostname_override=kerberos_hostname_override, @@ -130,6 +148,8 @@ def __init__( self.kerberos_delegation = kerberos_delegation self.kerberos_hostname_override = kerberos_hostname_override self.credssp_disable_tlsv1_2 = credssp_disable_tlsv1_2 + self.reconnection_retries = reconnection_retries + self.reconnection_backoff = reconnection_backoff def open_shell( self, diff --git a/winrm/tests/test_protocol.py b/winrm/tests/test_protocol.py index d7835bc..e885fc0 100644 --- a/winrm/tests/test_protocol.py +++ b/winrm/tests/test_protocol.py @@ -86,6 +86,12 @@ def test_set_timeout_as_sec(): assert protocol.operation_timeout_sec == 29 +def test_set_retry_connection(): + protocol = Protocol("endpoint", username="username", password="password", reconnection_retries="5", reconnection_backoff="3") + assert protocol.reconnection_retries == 5 + assert protocol.reconnection_backoff == 3.0 + + def test_fail_set_read_timeout_as_sec(): with pytest.raises(ValueError) as exc: Protocol("endpoint", username="username", password="password", read_timeout_sec="30a", operation_timeout_sec="29") @@ -96,3 +102,15 @@ def test_fail_set_operation_timeout_as_sec(): with pytest.raises(ValueError) as exc: Protocol("endpoint", username="username", password="password", read_timeout_sec=30, operation_timeout_sec="29a") assert str(exc.value) == "failed to parse operation_timeout_sec as int: " "invalid literal for int() with base 10: '29a'" + + +def test_fail_set_reconnection_retries(): + with pytest.raises(ValueError) as exc: + Protocol("endpoint", username="username", password="password", reconnection_retries="5a", reconnection_backoff=4.0) + assert str(exc.value) == "failed to parse reconnection_retries as int: " "invalid literal for int() with base 10: '5a'" + + +def test_fail_set_reconnection_backoff(): + with pytest.raises(ValueError) as exc: + Protocol("endpoint", username="username", password="password", reconnection_retries=5, reconnection_backoff="4a") + assert str(exc.value) == "failed to parse reconnection_backoff as float: " "could not convert string to float: '4a'" diff --git a/winrm/transport.py b/winrm/transport.py index fa2e781..ca46452 100644 --- a/winrm/transport.py +++ b/winrm/transport.py @@ -6,6 +6,7 @@ import requests import requests.auth +import urllib3 from winrm.encryption import Encryption from winrm.exceptions import InvalidCredentialsError, WinRMError, WinRMTransportError @@ -70,6 +71,8 @@ def __init__( cert_pem: str | None = None, cert_key_pem: str | None = None, read_timeout_sec: int | None = None, + reconnection_retries: int | None = 0, + reconnection_backoff: float = 2.0, server_cert_validation: t.Literal["validate", "ignore"] | None = "validate", kerberos_delegation: bool | str = False, kerberos_hostname_override: str | None = None, @@ -91,6 +94,8 @@ def __init__( self.cert_pem = cert_pem self.cert_key_pem = cert_key_pem self.read_timeout_sec = read_timeout_sec + self.reconnection_retries = reconnection_retries + self.reconnection_backoff = reconnection_backoff self.server_cert_validation = server_cert_validation self.kerberos_hostname_override = kerberos_hostname_override self.message_encryption = message_encryption @@ -186,6 +191,20 @@ def build_session(self) -> requests.Session: # Merge proxy environment variables settings = session.merge_environment_settings(url=self.endpoint, proxies=proxies, stream=None, verify=None, cert=None) + # Retry on connection errors, with a backoff factor + retries = urllib3.util.retry.Retry( + total=self.reconnection_retries, + connect=self.reconnection_retries, + read=0, + redirect=0, + status=self.reconnection_retries, + other=0, + status_forcelist=(425, 429, 503), + backoff_factor=self.reconnection_backoff, + ) + session.mount("http://", requests.adapters.HTTPAdapter(max_retries=retries)) + session.mount("https://", requests.adapters.HTTPAdapter(max_retries=retries)) + global DISPLAYED_PROXY_WARNING # We want to eventually stop reading proxy information from the environment.