diff --git a/src/Renci.SshNet/Security/IKeyExchange.cs b/src/Renci.SshNet/Security/IKeyExchange.cs
index f12a18322..7ffd2f465 100644
--- a/src/Renci.SshNet/Security/IKeyExchange.cs
+++ b/src/Renci.SshNet/Security/IKeyExchange.cs
@@ -38,8 +38,9 @@ public interface IKeyExchange : IDisposable
/// Starts the key exchange algorithm.
///
/// The session.
- /// Key exchange init message.
- void Start(Session session, KeyExchangeInitMessage message);
+ /// The key exchange init message received from the server.
+ /// Whether to send a key exchange init message in response.
+ void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage);
///
/// Finishes the key exchange algorithm.
diff --git a/src/Renci.SshNet/Security/KeyExchange.cs b/src/Renci.SshNet/Security/KeyExchange.cs
index 44684a92e..f01a4b117 100644
--- a/src/Renci.SshNet/Security/KeyExchange.cs
+++ b/src/Renci.SshNet/Security/KeyExchange.cs
@@ -61,16 +61,15 @@ public byte[] ExchangeHash
///
public event EventHandler HostKeyReceived;
- ///
- /// Starts key exchange algorithm.
- ///
- /// The session.
- /// Key exchange init message.
- public virtual void Start(Session session, KeyExchangeInitMessage message)
+ ///
+ public virtual void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
Session = session;
- SendMessage(session.ClientInitMessage);
+ if (sendClientInitMessage)
+ {
+ SendMessage(session.ClientInitMessage);
+ }
// Determine encryption algorithm
var clientEncryptionAlgorithmName = (from b in session.ConnectionInfo.Encryptions.Keys
diff --git a/src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs b/src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs
index 4f31514a7..7dfc51e34 100644
--- a/src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs
+++ b/src/Renci.SshNet/Security/KeyExchangeDiffieHellman.cs
@@ -76,14 +76,10 @@ protected override bool ValidateExchangeHash()
return ValidateExchangeHash(_hostKey, _signature);
}
- ///
- /// Starts key exchange algorithm.
- ///
- /// The session.
- /// Key exchange init message.
- public override void Start(Session session, KeyExchangeInitMessage message)
+ ///
+ public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
- base.Start(session, message);
+ base.Start(session, message, sendClientInitMessage);
_serverPayload = message.GetBytes();
_clientPayload = Session.ClientInitMessage.GetBytes();
diff --git a/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs b/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs
index 93703ee8f..5774f2c34 100644
--- a/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs
+++ b/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupExchangeShaBase.cs
@@ -39,14 +39,10 @@ protected override byte[] CalculateHash()
return Hash(groupExchangeHashData.GetBytes());
}
- ///
- /// Starts key exchange algorithm.
- ///
- /// The session.
- /// Key exchange init message.
- public override void Start(Session session, KeyExchangeInitMessage message)
+ ///
+ public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
- base.Start(session, message);
+ base.Start(session, message, sendClientInitMessage);
// Register SSH_MSG_KEX_DH_GEX_GROUP message
Session.RegisterMessage("SSH_MSG_KEX_DH_GEX_GROUP");
diff --git a/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs b/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs
index 63c2bba40..b0db30eaa 100644
--- a/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs
+++ b/src/Renci.SshNet/Security/KeyExchangeDiffieHellmanGroupShaBase.cs
@@ -13,14 +13,10 @@ internal abstract class KeyExchangeDiffieHellmanGroupShaBase : KeyExchangeDiffie
///
public abstract BigInteger GroupPrime { get; }
- ///
- /// Starts key exchange algorithm.
- ///
- /// The session.
- /// Key exchange init message.
- public override void Start(Session session, KeyExchangeInitMessage message)
+ ///
+ public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
- base.Start(session, message);
+ base.Start(session, message, sendClientInitMessage);
Session.RegisterMessage("SSH_MSG_KEXDH_REPLY");
diff --git a/src/Renci.SshNet/Security/KeyExchangeEC.cs b/src/Renci.SshNet/Security/KeyExchangeEC.cs
index 4368affbf..8bc61e7fc 100644
--- a/src/Renci.SshNet/Security/KeyExchangeEC.cs
+++ b/src/Renci.SshNet/Security/KeyExchangeEC.cs
@@ -78,14 +78,10 @@ protected override bool ValidateExchangeHash()
return ValidateExchangeHash(_hostKey, _signature);
}
- ///
- /// Starts key exchange algorithm.
- ///
- /// The session.
- /// Key exchange init message.
- public override void Start(Session session, KeyExchangeInitMessage message)
+ ///
+ public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
- base.Start(session, message);
+ base.Start(session, message, sendClientInitMessage);
_serverPayload = message.GetBytes();
_clientPayload = Session.ClientInitMessage.GetBytes();
diff --git a/src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs b/src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs
index 18443fe73..c6c060bab 100644
--- a/src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs
+++ b/src/Renci.SshNet/Security/KeyExchangeECCurve25519.cs
@@ -29,14 +29,10 @@ protected override int HashSize
get { return 256; }
}
- ///
- /// Starts key exchange algorithm.
- ///
- /// The session.
- /// Key exchange init message.
- public override void Start(Session session, KeyExchangeInitMessage message)
+ ///
+ public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
- base.Start(session, message);
+ base.Start(session, message, sendClientInitMessage);
Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");
diff --git a/src/Renci.SshNet/Security/KeyExchangeECDH.cs b/src/Renci.SshNet/Security/KeyExchangeECDH.cs
index c3fc7bfe4..c756fb6cb 100644
--- a/src/Renci.SshNet/Security/KeyExchangeECDH.cs
+++ b/src/Renci.SshNet/Security/KeyExchangeECDH.cs
@@ -24,14 +24,10 @@ internal abstract class KeyExchangeECDH : KeyExchangeEC
private ECDHCBasicAgreement _keyAgreement;
private ECDomainParameters _domainParameters;
- ///
- /// Starts key exchange algorithm.
- ///
- /// The session.
- /// Key exchange init message.
- public override void Start(Session session, KeyExchangeInitMessage message)
+ ///
+ public override void Start(Session session, KeyExchangeInitMessage message, bool sendClientInitMessage)
{
- base.Start(session, message);
+ base.Start(session, message, sendClientInitMessage);
Session.RegisterMessage("SSH_MSG_KEX_ECDH_REPLY");
diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs
index 5bf6d8eef..a57d5a1e7 100644
--- a/src/Renci.SshNet/Session.cs
+++ b/src/Renci.SshNet/Session.cs
@@ -160,12 +160,7 @@ public class Session : ISession
///
/// WaitHandle to signal that key exchange was completed.
///
- private EventWaitHandle _keyExchangeCompletedWaitHandle = new ManualResetEvent(initialState: false);
-
- ///
- /// WaitHandle to signal that key exchange is in progress.
- ///
- private bool _keyExchangeInProgress;
+ private ManualResetEventSlim _keyExchangeCompletedWaitHandle = new ManualResetEventSlim(initialState: false);
///
/// Exception that need to be thrown by waiting thread.
@@ -643,6 +638,11 @@ public void Connect()
// Some server implementations might sent this message first, prior to establishing encryption algorithm
RegisterMessage("SSH_MSG_USERAUTH_BANNER");
+ // Send our key exchange init.
+ // We need to do this before starting the message listener to avoid the case where we receive the server
+ // key exchange init and we continue the key exchange before having sent our own init.
+ SendMessage(ClientInitMessage);
+
// Mark the message listener threads as started
_ = _messageListenerCompleted.Reset();
@@ -651,7 +651,7 @@ public void Connect()
_ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
// Wait for key exchange to be completed
- WaitOnHandle(_keyExchangeCompletedWaitHandle);
+ WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
// If sessionId is not set then its not connected
if (SessionId is null)
@@ -757,6 +757,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
// Some server implementations might sent this message first, prior to establishing encryption algorithm
RegisterMessage("SSH_MSG_USERAUTH_BANNER");
+ // Send our key exchange init.
+ // We need to do this before starting the message listener to avoid the case where we receive the server
+ // key exchange init and we continue the key exchange before having sent our own init.
+ SendMessage(ClientInitMessage);
+
// Mark the message listener threads as started
_ = _messageListenerCompleted.Reset();
@@ -765,7 +770,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
_ = ThreadAbstraction.ExecuteThreadLongRunning(MessageListener);
// Wait for key exchange to be completed
- WaitOnHandle(_keyExchangeCompletedWaitHandle);
+ WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
// If sessionId is not set then its not connected
if (SessionId is null)
@@ -1046,10 +1051,10 @@ internal void SendMessage(Message message)
throw new SshConnectionException("Client not connected.");
}
- if (_keyExchangeInProgress && message is not IKeyExchangedAllowed)
+ if (!_keyExchangeCompletedWaitHandle.IsSet && message is not IKeyExchangedAllowed)
{
// Wait for key exchange to be completed
- WaitOnHandle(_keyExchangeCompletedWaitHandle);
+ WaitOnHandle(_keyExchangeCompletedWaitHandle.WaitHandle);
}
DiagnosticAbstraction.Log(string.Format("[{0}] Sending message '{1}' to server: '{2}'.", ToHex(SessionId), message.GetType().Name, message));
@@ -1394,9 +1399,15 @@ internal void OnKeyExchangeDhGroupExchangeReplyReceived(KeyExchangeDhGroupExchan
/// message.
internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
{
- _keyExchangeInProgress = true;
+ // If _keyExchangeCompletedWaitHandle is already set, then this is a key
+ // re-exchange initiated by the server, and we need to send our own init
+ // message.
+ // Otherwise, the wait handle is not set and this received init is part of the
+ // initial connection for which we have already sent our init, so we shouldn't
+ // send another one.
+ var sendClientInitMessage = _keyExchangeCompletedWaitHandle.IsSet;
- _ = _keyExchangeCompletedWaitHandle.Reset();
+ _keyExchangeCompletedWaitHandle.Reset();
// Disable messages that are not key exchange related
_sshMessageFactory.DisableNonKeyExchangeMessages();
@@ -1411,7 +1422,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message)
_keyExchange.HostKeyReceived += KeyExchange_HostKeyReceived;
// Start the algorithm implementation
- _keyExchange.Start(this, message);
+ _keyExchange.Start(this, message, sendClientInitMessage);
KeyExchangeInitReceived?.Invoke(this, new MessageEventArgs(message));
}
@@ -1477,9 +1488,7 @@ internal void OnNewKeysReceived(NewKeysMessage message)
NewKeysReceived?.Invoke(this, new MessageEventArgs(message));
// Signal that key exchange completed
- _ = _keyExchangeCompletedWaitHandle.Set();
-
- _keyExchangeInProgress = false;
+ _keyExchangeCompletedWaitHandle.Set();
}
///
@@ -1967,7 +1976,7 @@ private void RaiseError(Exception exp)
private void Reset()
{
_ = _exceptionWaitHandle?.Reset();
- _ = _keyExchangeCompletedWaitHandle?.Reset();
+ _keyExchangeCompletedWaitHandle?.Reset();
_ = _messageListenerCompleted?.Set();
SessionId = null;
@@ -1975,7 +1984,6 @@ private void Reset()
_isDisconnecting = false;
_isAuthenticated = false;
_exception = null;
- _keyExchangeInProgress = false;
}
private static SshConnectionException CreateConnectionAbortedByServerException()
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
index 7fa1ac24e..6331f7b9c 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs
@@ -49,6 +49,12 @@ public abstract class SessionTest_ConnectedBase
internal SshIdentification ServerIdentification { get; set; }
protected bool CallSessionConnectWhenArrange { get; set; }
+ ///
+ /// Should the "server" wait for the client kexinit before sending its own.
+ /// A regression test simulating e.g. cisco devices.
+ ///
+ protected bool WaitForClientKeyExchangeInit { get; set; }
+
[TestInitialize]
public void Setup()
{
@@ -59,18 +65,18 @@ public void Setup()
[TestCleanup]
public void TearDown()
{
- if (ServerSocket != null)
- {
- ServerSocket.Dispose();
- ServerSocket = null;
- }
-
if (ServerListener != null)
{
ServerListener.Dispose();
ServerListener = null;
}
+ if (ServerSocket != null)
+ {
+ ServerSocket.Dispose();
+ ServerSocket = null;
+ }
+
if (Session != null)
{
Session.Dispose();
@@ -115,6 +121,15 @@ protected virtual void SetupData()
var newKeysMessage = new NewKeysMessage();
var newKeys = newKeysMessage.GetPacket(8, null);
_ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
+
+ if (!_authenticationStarted)
+ {
+ var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication)
+ .Build();
+ _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
+
+ _authenticationStarted = true;
+ }
};
ServerListener = new AsyncSocketListener(_serverEndPoint)
@@ -125,36 +140,23 @@ protected virtual void SetupData()
{
ServerSocket = socket;
- // Since we're mocking the protocol version exchange, we'll immediately stat KEX upon
+ // Since we're mocking the protocol version exchange, we'll immediately start KEX upon
// having established the connection instead of when the client has been identified
- var keyExchangeInitMessage = new KeyExchangeInitMessage
- {
- CompressionAlgorithmsClientToServer = new string[0],
- CompressionAlgorithmsServerToClient = new string[0],
- EncryptionAlgorithmsClientToServer = new string[0],
- EncryptionAlgorithmsServerToClient = new string[0],
- KeyExchangeAlgorithms = new[] { _keyExchangeAlgorithm },
- LanguagesClientToServer = new string[0],
- LanguagesServerToClient = new string[0],
- MacAlgorithmsClientToServer = new string[0],
- MacAlgorithmsServerToClient = new string[0],
- ServerHostKeyAlgorithms = new string[0]
- };
- var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
- _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
+ if (!WaitForClientKeyExchangeInit)
+ {
+ SendKeyExchangeInit();
+ }
};
ServerListener.BytesReceived += (received, socket) =>
{
ServerBytesReceivedRegister.Add(received);
- if (!_authenticationStarted)
+ if (WaitForClientKeyExchangeInit && received.Length > 5 && received[5] == 20)
{
- var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication)
- .Build();
- _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
-
- _authenticationStarted = true;
+ // This is the KEXINIT. Send one back.
+ SendKeyExchangeInit();
+ WaitForClientKeyExchangeInit = false;
}
};
ServerListener.Start();
@@ -162,6 +164,25 @@ protected virtual void SetupData()
ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
CallSessionConnectWhenArrange = true;
+
+ void SendKeyExchangeInit()
+ {
+ var keyExchangeInitMessage = new KeyExchangeInitMessage
+ {
+ CompressionAlgorithmsClientToServer = new string[0],
+ CompressionAlgorithmsServerToClient = new string[0],
+ EncryptionAlgorithmsClientToServer = new string[0],
+ EncryptionAlgorithmsServerToClient = new string[0],
+ KeyExchangeAlgorithms = new[] { _keyExchangeAlgorithm },
+ LanguagesClientToServer = new string[0],
+ LanguagesServerToClient = new string[0],
+ MacAlgorithmsClientToServer = new string[0],
+ MacAlgorithmsServerToClient = new string[0],
+ ServerHostKeyAlgorithms = new string[0]
+ };
+ var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
+ _ = ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
+ }
}
private void CreateMocks()
@@ -187,7 +208,7 @@ private void SetupMocks()
_ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
_ = _keyExchangeMock.Setup(p => p.Name)
.Returns(_keyExchangeAlgorithm);
- _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny()));
+ _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny(), false));
_ = _keyExchangeMock.Setup(p => p.ExchangeHash)
.Returns(SessionId);
_ = _keyExchangeMock.Setup(p => p.CreateServerCipher())
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs
index 96797d727..11cda2d90 100644
--- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerAndClientDisconnectRace.cs
@@ -89,6 +89,13 @@ protected virtual void SetupData()
var newKeysMessage = new NewKeysMessage();
var newKeys = newKeysMessage.GetPacket(8, null);
_ = ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
+
+ if (!_authenticationStarted)
+ {
+ var serviceAcceptMessage = ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication).Build();
+ _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
+ _authenticationStarted = true;
+ }
};
ServerListener = new AsyncSocketListener(_serverEndPoint);
@@ -118,13 +125,6 @@ protected virtual void SetupData()
ServerListener.BytesReceived += (received, socket) =>
{
ServerBytesReceivedRegister.Add(received);
-
- if (!_authenticationStarted)
- {
- var serviceAcceptMessage =ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication).Build();
- _ = ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
- _authenticationStarted = true;
- }
};
ServerListener.Start();
@@ -156,7 +156,7 @@ private void SetupMocks()
.Returns(_keyExchangeMock.Object);
_ = _keyExchangeMock.Setup(p => p.Name)
.Returns(_keyExchangeAlgorithm);
- _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny()));
+ _ = _keyExchangeMock.Setup(p => p.Start(Session, It.IsAny(), false));
_ = _keyExchangeMock.Setup(p => p.ExchangeHash)
.Returns(SessionId);
_ = _keyExchangeMock.Setup(p => p.CreateServerCipher())
diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerDoesNotSendKexInit.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerDoesNotSendKexInit.cs
new file mode 100644
index 000000000..44bfa74fd
--- /dev/null
+++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerDoesNotSendKexInit.cs
@@ -0,0 +1,24 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+namespace Renci.SshNet.Tests.Classes
+{
+ [TestClass]
+ public class SessionTest_Connected_ServerDoesNotSendKexInit : SessionTest_ConnectedBase
+ {
+ protected override void SetupData()
+ {
+ WaitForClientKeyExchangeInit = true;
+
+ base.SetupData();
+ }
+
+ protected override void Act()
+ {
+ }
+
+ [TestMethod]
+ public void ConnectShouldSucceed()
+ {
+ }
+ }
+}