diff --git a/build.proj b/build.proj index 0443bbcf4e..f2febc7dbe 100644 --- a/build.proj +++ b/build.proj @@ -58,9 +58,13 @@ + + + - + + @@ -220,6 +224,7 @@ -p:TestTargetOS=Windows$(TargetGroup) --collect "Code coverage" --results-directory $(ResultsDirectory) + --filter "category!=failing" --logger:"trx;LogFilePrefix=Unit-Windows$(TargetGroup)-$(TestSet)" $(TestCommand.Replace($([System.Environment]::NewLine), " ")) @@ -240,6 +245,7 @@ -p:TestTargetOS=Unixnetcoreapp --collect "Code coverage" --results-directory $(ResultsDirectory) + --filter "category!=failing" --logger:"trx;LogFilePrefix=Unit-Unixnetcoreapp-$(TestSet)" $(TestCommand.Replace($([System.Environment]::NewLine), " ")) diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 38118dda72..c6eb35323c 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -1728,8 +1728,11 @@ private void LoginNoFailover(ServerInfo serverInfo, } if (_parser == null - || TdsParserState.Closed != _parser.State || IsDoNotRetryConnectError(sqlex) + // If state != closed, indicates that the parser encountered an error while processing the + // login response (e.g. an explicit error token). Transient network errors that impact + // connectivity will result in parser state being closed. + || TdsParserState.Closed != _parser.State || timeout.IsExpired) { // no more time to try again @@ -2008,6 +2011,9 @@ TimeoutTimer timeout throw; // Caller will call LoginFailure() } + // TODO: It doesn't make sense to connect to an azure sql server instance with a failover partner + // specified. Azure SQL Server does not support failover partners. Other availability technologies + // like Failover Groups should be used instead. if (!ADP.IsAzureSqlServerEndpoint(connectionOptions.DataSource) && IsConnectionDoomed) { throw; diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalizationTest.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalizationTest.cs index 77e7eee950..af27dce721 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalizationTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/LocalizationTest.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Globalization; using System.Threading; +using Microsoft.SqlServer.TDS.Servers; using Xunit; namespace Microsoft.Data.SqlClient.Tests @@ -55,9 +56,9 @@ private string GetLocalizedErrorMessage(string culture) Thread.CurrentThread.CurrentCulture = new CultureInfo(culture); Thread.CurrentThread.CurrentUICulture = new CultureInfo(culture); - using TestTdsServer server = TestTdsServer.StartTestServer(); - var connStr = server.ConnectionString; - connStr = connStr.Replace("localhost", "dummy"); + using TdsServer server = new TdsServer(new TdsServerArguments() { }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { DataSource = $"dummy,{server.EndPoint.Port}" }.ConnectionString; using SqlConnection connection = new SqlConnection(connStr); try diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj index 91a5a505b9..730d96ee19 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/Microsoft.Data.SqlClient.FunctionalTests.csproj @@ -36,8 +36,6 @@ - - @@ -64,8 +62,6 @@ - - @@ -91,6 +87,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionReadOnlyRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionReadOnlyRoutingTests.cs deleted file mode 100644 index c3574dbc13..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionReadOnlyRoutingTests.cs +++ /dev/null @@ -1,140 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Net; -using System.Threading.Tasks; -using Microsoft.SqlServer.TDS.Servers; -using Xunit; - -namespace Microsoft.Data.SqlClient.Tests -{ - public class SqlConnectionReadOnlyRoutingTests - { - [Fact] - public void NonRoutedConnection() - { - using TestTdsServer server = TestTdsServer.StartTestServer(); - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(server.ConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); - connection.Open(); - } - - [Fact] - public async Task NonRoutedAsyncConnection() - { - using TestTdsServer server = TestTdsServer.StartTestServer(); - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(server.ConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); - await connection.OpenAsync(); - } - - [Fact] - public void RoutedConnection() - => RecursivelyRoutedConnection(1); - - [Fact] - public async Task RoutedAsyncConnection() - => await RecursivelyRoutedAsyncConnection(1); - - [Theory] - [InlineData(2)] - [InlineData(9)] - [InlineData(11)] // The driver rejects more than 10 redirects (11 layers of redirecting servers) - public void RecursivelyRoutedConnection(int layers) - { - TestTdsServer innerServer = TestTdsServer.StartTestServer(); - IPEndPoint lastEndpoint = innerServer.Endpoint; - Stack routingLayers = new(layers + 1); - string lastConnectionString = innerServer.ConnectionString; - - try - { - routingLayers.Push(innerServer); - for (int i = 0; i < layers; i++) - { - TestRoutingTdsServer router = TestRoutingTdsServer.StartTestServer(lastEndpoint); - - routingLayers.Push(router); - lastEndpoint = router.Endpoint; - lastConnectionString = router.ConnectionString; - } - - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(lastConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); - connection.Open(); - } - finally - { - while (routingLayers.Count > 0) - { - GenericTDSServer layer = routingLayers.Pop(); - - if (layer is IDisposable disp) - { - disp.Dispose(); - } - } - } - } - - [Theory] - [InlineData(2)] - [InlineData(9)] - [InlineData(11)] // The driver rejects more than 10 redirects (11 layers of redirecting servers) - public async Task RecursivelyRoutedAsyncConnection(int layers) - { - TestTdsServer innerServer = TestTdsServer.StartTestServer(); - IPEndPoint lastEndpoint = innerServer.Endpoint; - Stack routingLayers = new(layers + 1); - string lastConnectionString = innerServer.ConnectionString; - - try - { - routingLayers.Push(innerServer); - for (int i = 0; i < layers; i++) - { - TestRoutingTdsServer router = TestRoutingTdsServer.StartTestServer(lastEndpoint); - - routingLayers.Push(router); - lastEndpoint = router.Endpoint; - lastConnectionString = router.ConnectionString; - } - - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(lastConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; - using SqlConnection connection = new SqlConnection(builder.ConnectionString); - await connection.OpenAsync(); - } - finally - { - while (routingLayers.Count > 0) - { - GenericTDSServer layer = routingLayers.Pop(); - - if (layer is IDisposable disp) - { - disp.Dispose(); - } - } - } - } - - [Fact] - public void ConnectionRoutingLimit() - { - SqlException sqlEx = Assert.Throws(() => RecursivelyRoutedConnection(12)); // This will fail on the 11th redirect - - Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase); - } - - [Fact] - public async Task AsyncConnectionRoutingLimit() - { - SqlException sqlEx = await Assert.ThrowsAsync(() => RecursivelyRoutedAsyncConnection(12)); // This will fail on the 11th redirect - - Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase); - } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestRoutingTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestRoutingTdsServer.cs deleted file mode 100644 index 130b50cad9..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestRoutingTdsServer.cs +++ /dev/null @@ -1,64 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Net; -using System.Runtime.CompilerServices; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.Servers; - -namespace Microsoft.Data.SqlClient.Tests -{ - internal class TestRoutingTdsServer : RoutingTDSServer, IDisposable - { - private const int DefaultConnectionTimeout = 5; - - private TDSServerEndPoint _endpoint = null; - - private SqlConnectionStringBuilder _connectionStringBuilder; - - public TestRoutingTdsServer(RoutingTDSServerArguments args) : base(args) { } - - public static TestRoutingTdsServer StartTestServer(IPEndPoint destinationEndpoint, bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, bool excludeEncryption = false, [CallerMemberName] string methodName = "") - { - RoutingTDSServerArguments args = new RoutingTDSServerArguments() - { - Log = enableLog ? Console.Out : null, - RoutingTCPHost = destinationEndpoint.Address.ToString() == IPAddress.Any.ToString() ? IPAddress.Loopback.ToString() : destinationEndpoint.Address.ToString(), - RoutingTCPPort = (ushort)destinationEndpoint.Port, - }; - - if (enableFedAuth) - { - args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired; - } - if (excludeEncryption) - { - args.Encryption = SqlServer.TDS.PreLogin.TDSPreLoginTokenEncryptionType.None; - } - - TestRoutingTdsServer server = new TestRoutingTdsServer(args); - server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; - server._endpoint.EndpointName = methodName; - // The server EventLog should be enabled as it logs the exceptions. - server._endpoint.EventLog = enableLog ? Console.Out : null; - server._endpoint.Start(); - - int port = server._endpoint.ServerEndPoint.Port; - server._connectionStringBuilder = excludeEncryption - // Allow encryption to be set when encryption is to be excluded from pre-login response. - ? new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Mandatory } - : new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Optional }; - server.ConnectionString = server._connectionStringBuilder.ConnectionString; - server.Endpoint = server._endpoint.ServerEndPoint; - return server; - } - - public void Dispose() => _endpoint?.Stop(); - - public string ConnectionString { get; private set; } - - public IPEndPoint Endpoint { get; private set; } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs deleted file mode 100644 index a5976fd6d5..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/TestTdsServer.cs +++ /dev/null @@ -1,76 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Net; -using System.Runtime.CompilerServices; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.Servers; - -namespace Microsoft.Data.SqlClient.Tests -{ - internal class TestTdsServer : GenericTDSServer, IDisposable - { - private const int DefaultConnectionTimeout = 5; - - private TDSServerEndPoint _endpoint = null; - - private SqlConnectionStringBuilder _connectionStringBuilder; - - public TestTdsServer(TDSServerArguments args) : base(args) { } - - public TestTdsServer(QueryEngine engine, TDSServerArguments args) : base(args) - { - Engine = engine; - } - - public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, bool excludeEncryption = false, Version serverVersion = null, [CallerMemberName] string methodName = "") - { - TDSServerArguments args = new TDSServerArguments() - { - Log = enableLog ? Console.Out : null, - }; - - if (enableFedAuth) - { - args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired; - } - if (excludeEncryption) - { - args.Encryption = SqlServer.TDS.PreLogin.TDSPreLoginTokenEncryptionType.None; - } - if (serverVersion != null) - { - args.ServerVersion = serverVersion; - } - - TestTdsServer server = engine == null ? new TestTdsServer(args) : new TestTdsServer(engine, args); - server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; - server._endpoint.EndpointName = methodName; - // The server EventLog should be enabled as it logs the exceptions. - server._endpoint.EventLog = enableLog ? Console.Out : null; - server._endpoint.Start(); - - int port = server._endpoint.ServerEndPoint.Port; - server._connectionStringBuilder = excludeEncryption - // Allow encryption to be set when encryption is to be excluded from pre-login response. - ? new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Mandatory } - : new SqlConnectionStringBuilder() { DataSource = "localhost," + port, ConnectTimeout = connectionTimeout, Encrypt = SqlConnectionEncryptOption.Optional }; - server.ConnectionString = server._connectionStringBuilder.ConnectionString; - server.Endpoint = server._endpoint.ServerEndPoint; - return server; - } - - public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, int connectionTimeout = DefaultConnectionTimeout, bool excludeEncryption = false, Version serverVersion = null, [CallerMemberName] string methodName = "") - { - return StartServerWithQueryEngine(null, enableFedAuth, enableLog, connectionTimeout, excludeEncryption, serverVersion, methodName); - } - - public void Dispose() => _endpoint?.Stop(); - - public string ConnectionString { get; private set; } - - public IPEndPoint Endpoint { get; private set; } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 7bd503c5b9..312a03d0b8 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -300,7 +300,6 @@ - diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs index 48b0c9273e..c64a05dcbf 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs @@ -12,6 +12,7 @@ using System.ServiceProcess; using System.Text; using Microsoft.Data.SqlClient.ManualTesting.Tests.DataCommon; +using Microsoft.SqlServer.TDS.Servers; using Microsoft.Win32; using Xunit; @@ -129,18 +130,23 @@ private void ConnectionTest(ConnectionTestParameters connectionTestParameters) string userId = string.IsNullOrWhiteSpace(builder.UserID) ? "user" : builder.UserID; string password = string.IsNullOrWhiteSpace(builder.Password) ? "password" : builder.Password; - using TestTdsServer server = TestTdsServer.StartTestServer(enableFedAuth: false, enableLog: false, connectionTimeout: 15, - methodName: "", -#if NET9_0_OR_GREATER - X509CertificateLoader.LoadPkcs12FromFile(s_fullPathToPfx, "nopassword", X509KeyStorageFlags.UserKeySet), -#else - new X509Certificate2(s_fullPathToPfx, "nopassword", X509KeyStorageFlags.UserKeySet), -#endif - encryptionProtocols: connectionTestParameters.EncryptionProtocols, - encryptionType: connectionTestParameters.TdsEncryptionType); - - builder = new(server.ConnectionString) + using TdsServer server = new TdsServer(new TdsServerArguments { + #if NET9_0_OR_GREATER + EncryptionCertificate = X509CertificateLoader.LoadPkcs12FromFile(s_fullPathToPfx, "nopassword", X509KeyStorageFlags.UserKeySet), + #else + EncryptionCertificate = new X509Certificate2(s_fullPathToPfx, "nopassword", X509KeyStorageFlags.UserKeySet), + #endif + EncryptionProtocols = connectionTestParameters.EncryptionProtocols, + Encryption = connectionTestParameters.TdsEncryptionType, + }); + + server.Start(); + + builder = new() + { + DataSource = $"localhost,{server.EndPoint.Port}", + ConnectTimeout = 15, UserID = userId, Password = password, TrustServerCertificate = connectionTestParameters.TrustServerCertificate, diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ExceptionTest/ConnectionExceptionTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ExceptionTest/ConnectionExceptionTest.cs index 6ee0681a0d..c44ed97ed0 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ExceptionTest/ConnectionExceptionTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ExceptionTest/ConnectionExceptionTest.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using Microsoft.SqlServer.TDS.Servers; using Xunit; namespace Microsoft.Data.SqlClient.ManualTesting.Tests @@ -23,8 +24,14 @@ public class ConnectionExceptionTest [ConditionalFact(nameof(IsNotKerberos))] public void TestConnectionStateWithErrorClass20() { - using TestTdsServer server = TestTdsServer.StartTestServer(); - using SqlConnection conn = new(server.ConnectionString); + using TdsServer server = new TdsServer(); + server.Start(); + using SqlConnection conn = new( + new SqlConnectionStringBuilder + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString); conn.Open(); SqlCommand cmd = conn.CreateCommand(); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/DiagnosticTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/DiagnosticTest.cs index b8649d43d2..4ae426fcbb 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/DiagnosticTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/DiagnosticTest.cs @@ -483,7 +483,7 @@ public void ConnectionOpenAsyncErrorTest() }).Dispose(); } - private static void CollectStatisticsDiagnostics(Action sqlOperation, bool enableServerLogging = false, [CallerMemberName] string methodName = "") + private static void CollectStatisticsDiagnostics(Action sqlOperation, [CallerMemberName] string methodName = "") { bool statsLogged = false; bool operationHasError = false; @@ -670,10 +670,19 @@ private static void CollectStatisticsDiagnostics(Action sqlOperation, bo { Console.WriteLine(string.Format("Test: {0} Enabled Listeners", methodName)); - using (var server = TestTdsServer.StartServerWithQueryEngine(new DiagnosticsQueryEngine(), enableLog: enableServerLogging, methodName: methodName)) + + using (var server = new TdsServer(new DiagnosticsQueryEngine(), new TdsServerArguments())) { + server.Start(methodName); Console.WriteLine(string.Format("Test: {0} Started Server", methodName)); - sqlOperation(server.ConnectionString); + + var connectionString = new SqlConnectionStringBuilder + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + + sqlOperation(connectionString); Console.WriteLine(string.Format("Test: {0} SqlOperation Successful", methodName)); @@ -859,11 +868,17 @@ private static async Task CollectStatisticsDiagnosticsAsync(Func s using (DiagnosticListener.AllListeners.Subscribe(diagnosticListenerObserver)) { Console.WriteLine(string.Format("Test: {0} Enabled Listeners", methodName)); - using (var server = TestTdsServer.StartServerWithQueryEngine(new DiagnosticsQueryEngine(), methodName: methodName)) + using (var server = new TdsServer(new DiagnosticsQueryEngine(), new TdsServerArguments())) { + server.Start(methodName); Console.WriteLine(string.Format("Test: {0} Started Server", methodName)); - await sqlOperation(server.ConnectionString); + var connectionString = new SqlConnectionStringBuilder + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + await sqlOperation(connectionString); Console.WriteLine(string.Format("Test: {0} SqlOperation Successful", methodName)); @@ -890,7 +905,7 @@ private static T GetPropertyValueFromType(object obj, string propName) public class DiagnosticsQueryEngine : QueryEngine { - public DiagnosticsQueryEngine() : base(new TDSServerArguments()) + public DiagnosticsQueryEngine() : base(new TdsServerArguments()) { } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs deleted file mode 100644 index 45a817c46e..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/TracingTests/TestTdsServer.cs +++ /dev/null @@ -1,93 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Linq; -using System.Net; -using System.Net.Sockets; -using System.Runtime.CompilerServices; -using System.Security.Authentication; -using System.Security.Cryptography.X509Certificates; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.PreLogin; -using Microsoft.SqlServer.TDS.Servers; - -namespace Microsoft.Data.SqlClient.ManualTesting.Tests -{ - internal class TestTdsServer : GenericTDSServer, IDisposable - { - private const int DefaultConnectionTimeout = 5; - - private TDSServerEndPoint _endpoint = null; - - private SqlConnectionStringBuilder _connectionStringBuilder; - - public TestTdsServer(TDSServerArguments args) : base(args) { } - - public TestTdsServer(QueryEngine engine, TDSServerArguments args) : base(args) - { - Engine = engine; - } - - public static TestTdsServer StartServerWithQueryEngine(QueryEngine engine, bool enableFedAuth = false, bool enableLog = false, - int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "", - X509Certificate2 encryptionCertificate = null, SslProtocols encryptionProtocols = SslProtocols.Tls12, TDSPreLoginTokenEncryptionType encryptionType = TDSPreLoginTokenEncryptionType.NotSupported) - { - TDSServerArguments args = new TDSServerArguments() - { - Log = enableLog ? Console.Out : null, - }; - - if (enableFedAuth) - { - args.FedAuthRequiredPreLoginOption = SqlServer.TDS.PreLogin.TdsPreLoginFedAuthRequiredOption.FedAuthRequired; - } - - args.EncryptionCertificate = encryptionCertificate; - args.EncryptionProtocols = encryptionProtocols; - args.Encryption = encryptionType; - - TestTdsServer server = engine == null ? new TestTdsServer(args) : new TestTdsServer(engine, args); - - server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; - server._endpoint.EndpointName = methodName; - // The server EventLog should be enabled as it logs the exceptions. - server._endpoint.EventLog = enableLog ? Console.Out : null; - server._endpoint.Start(); - - int port = server._endpoint.ServerEndPoint.Port; - - server._connectionStringBuilder = new SqlConnectionStringBuilder() - { - DataSource = "localhost," + port, - ConnectTimeout = connectionTimeout, - }; - - if (encryptionType == TDSPreLoginTokenEncryptionType.Off || - encryptionType == TDSPreLoginTokenEncryptionType.None || - encryptionType == TDSPreLoginTokenEncryptionType.NotSupported) - { - server._connectionStringBuilder.Encrypt = SqlConnectionEncryptOption.Optional; - } - else - { - server._connectionStringBuilder.Encrypt = SqlConnectionEncryptOption.Mandatory; - } - - server.ConnectionString = server._connectionStringBuilder.ConnectionString; - return server; - } - - public static TestTdsServer StartTestServer(bool enableFedAuth = false, bool enableLog = false, - int connectionTimeout = DefaultConnectionTimeout, [CallerMemberName] string methodName = "", - X509Certificate2 encryptionCertificate = null, SslProtocols encryptionProtocols = SslProtocols.Tls12, TDSPreLoginTokenEncryptionType encryptionType = TDSPreLoginTokenEncryptionType.NotSupported) - { - return StartServerWithQueryEngine(null, enableFedAuth, enableLog, connectionTimeout, methodName, encryptionCertificate, encryptionProtocols, encryptionType); - } - - public void Dispose() => _endpoint?.Stop(); - - public string ConnectionString { get; private set; } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj index 2f0e12c922..44392c2f81 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/Microsoft.Data.SqlClient.UnitTests.csproj @@ -10,6 +10,7 @@ + runtime; build; native; contentfiles; analyzers; buildtransitive @@ -25,6 +26,9 @@ all + + + diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs new file mode 100644 index 0000000000..ce2318b569 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs @@ -0,0 +1,534 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using Microsoft.SqlServer.TDS.Servers; +using Xunit; + +namespace Microsoft.Data.SqlClient.ScenarioTests +{ + public class ConnectionFailoverTests + { + //TODO parameterize for transient errors + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_NoFailover_DoesNotClearPool(uint errorCode) + { + // When connecting to a server with a configured failover partner, + // transient errors returned during the login ack should not clear the connection pool. + + // Arrange + using TdsServer failoverServer = new TdsServer(new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234" + }); + failoverServer.Start(); + var failoverDataSource = $"localhost,{failoverServer.EndPoint.Port}"; + + // Errors are off to start to allow the pool to warm up + using TransientFaultTdsServer initialServer = new TransientFaultTdsServer(new TransientFaultTdsServerArguments + { + FailoverPartner = failoverDataSource + }); + initialServer.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + initialServer.EndPoint.Port, + ConnectRetryInterval = 1, + ConnectTimeout = 30, + Encrypt = SqlConnectionEncryptOption.Optional, + InitialCatalog = "test" + }; + + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + + // Act + initialServer.SetErrorBehavior(true, errorCode); + using SqlConnection secondConnection = new(builder.ConnectionString); + // Should not trigger a failover, will retry against the same server + secondConnection.Open(); + + // Request a new connection, should initiate a fresh connection attempt if the pool was cleared. + connection.Close(); + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal(ConnectionState.Open, secondConnection.State); + Assert.Equal($"localhost,{initialServer.EndPoint.Port}", connection.DataSource); + Assert.Equal($"localhost,{initialServer.EndPoint.Port}", secondConnection.DataSource); + + // 1 for the initial connection, 2 for the second connection + Assert.Equal(3, initialServer.PreLoginCount); + // A failover should not be triggered, so prelogin count to the failover server should be 0 + Assert.Equal(0, failoverServer.PreLoginCount); + } + + [Fact] + public void NetworkError_TriggersFailover_ClearsPool() + { + // When connecting to a server with a configured failover partner, + // network errors returned during prelogin should clear the connection pool. + + // Arrange + using TdsServer failoverServer = new TdsServer(new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234" + }); + failoverServer.Start(); + var failoverDataSource = $"localhost,{failoverServer.EndPoint.Port}"; + + // Errors are off to start to allow the pool to warm up + using TransientFaultTdsServer initialServer = new TransientFaultTdsServer(new TransientFaultTdsServerArguments + { + FailoverPartner = failoverDataSource + }); + initialServer.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + initialServer.EndPoint.Port, + ConnectRetryInterval = 1, + ConnectTimeout = 30, + Encrypt = SqlConnectionEncryptOption.Optional, + InitialCatalog = "test" + }; + + // Open the initial connection to warm up the pool and populate failover partner information + // for the pool group. + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{initialServer.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, initialServer.PreLoginCount); + Assert.Equal(0, failoverServer.PreLoginCount); + + // Act + // Should trigger a failover because the initial server is unavailable + initialServer.Dispose(); + using SqlConnection secondConnection = new(builder.ConnectionString); + secondConnection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, secondConnection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", secondConnection.DataSource); + Assert.Equal(1, initialServer.PreLoginCount); + Assert.Equal(1, failoverServer.PreLoginCount); + + + // Act + // Request a new connection, should initiate a fresh connection attempt if the pool was cleared. + connection.Close(); + connection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, initialServer.PreLoginCount); + Assert.Equal(2, failoverServer.PreLoginCount); + } + + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3527")] + [Fact] + public void NetworkError_RetryDisabled_ShouldFail() + { + using TdsServer failoverServer = new TdsServer( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + // Arrange + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(1000), + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + server.EndPoint.Port, + InitialCatalog = "master",// Required for failover partner to work + ConnectTimeout = 5, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + + // Act + Assert.Throws(() => connection.Open()); + + // Assert + // On the first connection attempt, no failover partner information is available, + // so the connection will retry on the same server. + Assert.Equal(ConnectionState.Closed, connection.State); + Assert.Equal(1, server.PreLoginCount); + Assert.Equal(0, failoverServer.PreLoginCount); + } + + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3528")] + [Fact] + public void NetworkError_RetryEnabled_ShouldConnectToPrimary() + { + using TdsServer failoverServer = new TdsServer( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + // Arrange + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(1000), + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + server.EndPoint.Port, + InitialCatalog = "master",// Required for failover partner to work + ConnectTimeout = 5, + ConnectRetryInterval = 1, + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + // On the first connection attempt, no failover partner information is available, + // so the connection will retry on the same server. + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + Assert.Equal(2, server.PreLoginCount); + Assert.Equal(0, failoverServer.PreLoginCount); + } + + [Fact] + public void NetworkError_WithUserProvidedPartner_RetryDisabled_ShouldConnectToFailoverPartner() + { + using TdsServer failoverServer = new TdsServer( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + // Arrange + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(5000), + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + server.EndPoint.Port, + InitialCatalog = "master", // Required for failover partner to work + ConnectTimeout = 5, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", // User provided failover partner + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + // On the first connection attempt, failover partner information is available in the connection string, + // so the connection will retry on the failover server. + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, server.PreLoginCount); + Assert.Equal(1, failoverServer.PreLoginCount); + } + + [Fact] + public void NetworkError_WithUserProvidedPartner_RetryEnabled_ShouldConnectToFailoverPartner() + { + using TdsServer failoverServer = new TdsServer( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + // Arrange + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(1000), + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + server.EndPoint.Port, + InitialCatalog = "master", // Required for failover partner to work + ConnectTimeout = 5, + ConnectRetryInterval = 1, + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", // User provided failover partner + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + // On the first connection attempt, failover partner information is available in the connection string, + // so the connection will retry on the failover server. + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", connection.DataSource); + Assert.Equal(1, server.PreLoginCount); + Assert.Equal(1, failoverServer.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_ShouldConnectToPrimary(uint errorCode) + { + // Arrange + using TdsServer failoverServer = new TdsServer( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost:1234", + }); + failoverServer.Start(); + + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + ConnectTimeout = 30, + ConnectRetryInterval = 1, + Encrypt = false + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the original server, resulting in a login count of 2 + Assert.Equal(2, server.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_RetryDisabled_ShouldFail(uint errorCode) + { + // Arrange + using TdsServer failoverServer = new TdsServer( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost:1234", + }); + failoverServer.Start(); + + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + ConnectTimeout = 30, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (SqlException e) + { + Assert.Equal((int)errorCode, e.Number); + return; + } + + Assert.Fail(); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_WithUserProvidedPartner_ShouldConnectToPrimary(uint errorCode) + { + // Arrange + using TdsServer failoverServer = new TdsServer( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost:1234", + }); + failoverServer.Start(); + + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + ConnectTimeout = 30, + ConnectRetryInterval = 1, + Encrypt = false, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", // User provided failover partner + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the original server, resulting in a login count of 2 + Assert.Equal(2, server.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFault_WithUserProvidedPartner_RetryDisabled_ShouldFail(uint errorCode) + { + // Arrange + using TdsServer failoverServer = new TdsServer( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost:1234", + }); + failoverServer.Start(); + + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + ConnectTimeout = 30, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false, + FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", // User provided failover partner + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (SqlException e) + { + Assert.Equal((int)errorCode, e.Number); + return; + } + + Assert.Fail(); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs new file mode 100644 index 0000000000..324a58213b --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionReadOnlyRoutingTests.cs @@ -0,0 +1,157 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Threading.Tasks; +using Microsoft.SqlServer.TDS.Servers; +using Xunit; + +namespace Microsoft.Data.SqlClient.ScenarioTests +{ + public class ConnectionReadOnlyRoutingTests + { + [Fact] + public void NonRoutedConnection() + { + using TdsServer server = new TdsServer(); + server.Start(); + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}", + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = SqlConnectionEncryptOption.Optional + }; + using SqlConnection connection = new SqlConnection(builder.ConnectionString); + connection.Open(); + } + + [Fact] + public async Task NonRoutedAsyncConnection() + { + using TdsServer server = new TdsServer(); + server.Start(); + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}", + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = SqlConnectionEncryptOption.Optional + }; + using SqlConnection connection = new SqlConnection(builder.ConnectionString); + await connection.OpenAsync(); + } + + [Fact] + public void RoutedConnection() + => RecursivelyRoutedConnection(1); + + [Fact] + public async Task RoutedAsyncConnection() + => await RecursivelyRoutedAsyncConnection(1); + + [Theory] + [InlineData(11)] // 11 layers of routing should succeed, 12 should fail + public void RecursivelyRoutedConnection(int layers) + { + using TdsServer innerServer = new TdsServer(); + innerServer.Start(); + IPEndPoint lastEndpoint = innerServer.EndPoint; + Stack routingLayers = new(layers + 1); + string lastConnectionString = (new SqlConnectionStringBuilder() { DataSource = $"localhost,{lastEndpoint.Port}" }).ConnectionString; + + try + { + for (int i = 0; i < layers; i++) + { + RoutingTdsServer router = new RoutingTdsServer( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)lastEndpoint.Port, + }); + router.Start(); + routingLayers.Push(router); + lastEndpoint = router.EndPoint; + lastConnectionString = (new SqlConnectionStringBuilder() { + DataSource = $"localhost,{lastEndpoint.Port}", + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = false + }).ConnectionString; + } + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(lastConnectionString) { ApplicationIntent = ApplicationIntent.ReadOnly }; + using SqlConnection connection = new SqlConnection(builder.ConnectionString); + connection.Open(); + } + finally + { + while (routingLayers.Count > 0) + { + routingLayers.Pop().Dispose(); + } + } + } + + [Theory] + [InlineData(11)] // 11 layers of routing should succeed, 12 should fail + public async Task RecursivelyRoutedAsyncConnection(int layers) + { + using TdsServer innerServer = new TdsServer(); + innerServer.Start(); + IPEndPoint lastEndpoint = innerServer.EndPoint; + Stack routingLayers = new(layers + 1); + string lastConnectionString = (new SqlConnectionStringBuilder() { DataSource = $"localhost,{lastEndpoint.Port}" }).ConnectionString; + + try + { + for (int i = 0; i < layers; i++) + { + RoutingTdsServer router = new RoutingTdsServer( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)lastEndpoint.Port, + }); + router.Start(); + routingLayers.Push(router); + lastEndpoint = router.EndPoint; + lastConnectionString = (new SqlConnectionStringBuilder() { + DataSource = $"localhost,{lastEndpoint.Port}", + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = false + }).ConnectionString; + } + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(lastConnectionString) { + ApplicationIntent = ApplicationIntent.ReadOnly, + Encrypt = false + }; + using SqlConnection connection = new SqlConnection(builder.ConnectionString); + await connection.OpenAsync(); + } + finally + { + while (routingLayers.Count > 0) + { + routingLayers.Pop().Dispose(); + } + } + } + + [Fact] + public void ConnectionRoutingLimit() + { + SqlException sqlEx = Assert.Throws(() => RecursivelyRoutedConnection(12)); // This will fail on the 11th redirect + + Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase); + } + + [Fact] + public async Task AsyncConnectionRoutingLimit() + { + SqlException sqlEx = await Assert.ThrowsAsync(() => RecursivelyRoutedAsyncConnection(12)); // This will fail on the 11th redirect + + Assert.Contains("Too many redirections have occurred.", sqlEx.Message, StringComparison.InvariantCultureIgnoreCase); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs new file mode 100644 index 0000000000..858095a0ca --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionRoutingTests.cs @@ -0,0 +1,265 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using Microsoft.SqlServer.TDS.Servers; +using Xunit; + +namespace Microsoft.Data.SqlClient.ScenarioTests +{ + public class ConnectionRoutingTests + { + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFaultAtRoutedLocation_ShouldReturnToGateway(uint errorCode) + { + // Arrange + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + + server.Start(); + + using RoutingTdsServer router = new RoutingTdsServer( + new RoutingTdsServerArguments() + { + //RoutingTCPHost = server.EndPoint.Address.ToString() == IPAddress.Any.ToString() ? IPAddress.Loopback.ToString() : server.EndPoint.Address.ToString(), + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 30, + ConnectRetryInterval = 1, + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + // Routing does not update the connection's data source + Assert.Equal($"localhost,{router.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the original server, resulting in a login count of 2 + Assert.Equal(2, router.PreLoginCount); + Assert.Equal(2, server.PreLoginCount); + } + + [Theory] + [InlineData(40613)] + [InlineData(42108)] + [InlineData(42109)] + public void TransientFaultAtRoutedLocation_RetryDisabled_ShouldFail(uint errorCode) + { + // Arrange + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + + server.Start(); + + using RoutingTdsServer router = new RoutingTdsServer( + new RoutingTdsServerArguments() + { + //RoutingTCPHost = server.EndPoint.Address.ToString() == IPAddress.Any.ToString() ? IPAddress.Loopback.ToString() : server.EndPoint.Address.ToString(), + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 30, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // Disable retry + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + //TODO validate exception type + Assert.Throws(() => connection.Open()); + } + + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3528")] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3527")] + [Fact] + public void NetworkErrorAtRoutedLocation_ShouldReturnToGateway() + { + // Arrange + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(1000), + }); + + server.Start(); + + using RoutingTdsServer router = new RoutingTdsServer( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 5, + ConnectRetryInterval = 1, + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + // Routing does not update the connection's data source + Assert.Equal($"localhost,{router.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the original server, resulting in a login count of 2 + Assert.Equal(2, router.PreLoginCount); + Assert.Equal(2, server.PreLoginCount); + } + + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3528")] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3527")] + [Fact] + public void NetworkErrorAtRoutedLocation_RetryDisabled_ShouldFail() + { + // Arrange + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(1000), + }); + + server.Start(); + + using RoutingTdsServer router = new RoutingTdsServer( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 5, + ConnectRetryInterval = 1, + ConnectRetryCount = 0, // disable retry + Encrypt = false, + }; + using SqlConnection connection = new(builder.ConnectionString); + //TODO validate exception type + Assert.Throws(() => connection.Open()); + } + + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3528")] + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3527")] + [Fact] + public void NetworkErrorDuringCommand_ShouldReturnToGateway() + { + // Arrange + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = false, + SleepDuration = TimeSpan.FromMilliseconds(1000), + }); + + server.Start(); + + using RoutingTdsServer router = new RoutingTdsServer( + new RoutingTdsServerArguments() + { + RoutingTCPHost = "localhost", + RoutingTCPPort = (ushort)server.EndPoint.Port, + }); + router.Start(); + + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder() + { + DataSource = "localhost," + router.EndPoint.Port, + ApplicationIntent = ApplicationIntent.ReadOnly, + ConnectTimeout = 5, + ConnectRetryInterval = 1, + Encrypt = false, + CommandTimeout = 5, + ConnectRetryCount = 1 + }; + using SqlConnection connection = new(builder.ConnectionString); + try + { + // Act + connection.Open(); + } + catch (Exception e) + { + Assert.Fail(e.Message); + } + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + // Routing does not update the connection's data source + Assert.Equal($"localhost,{router.EndPoint.Port}", connection.DataSource); + + Assert.Equal(1, router.PreLoginCount); + Assert.Equal(1, server.PreLoginCount); + + // Break the connection to force a reconnect + server.KillAllConnections(); + + server.SetTransientTimeoutBehavior(true, TimeSpan.FromMilliseconds(1000)); + + SqlCommand command = new SqlCommand("Select 1;", connection); + command.ExecuteScalar(); + + // Assert + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{router.EndPoint.Port}", connection.DataSource); + + // Failures should prompt the client to return to the gateway + Assert.Equal(3, router.PreLoginCount); + Assert.Equal(3, server.PreLoginCount); + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionTests.cs similarity index 65% rename from src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs rename to src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionTests.cs index 616a8fec6f..42fa84b47b 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/SqlConnectionBasicTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionTests.cs @@ -9,7 +9,6 @@ using System.Globalization; using System.Linq; using System.Reflection; -using System.Runtime.InteropServices; using System.Security; using System.Threading; using System.Threading.Tasks; @@ -20,24 +19,34 @@ using Microsoft.SqlServer.TDS.Servers; using Xunit; -namespace Microsoft.Data.SqlClient.Tests +namespace Microsoft.Data.SqlClient.ScenarioTests { - public class SqlConnectionBasicTests + public class ConnectionTests { [Fact] public void ConnectionTest() { - using TestTdsServer server = TestTdsServer.StartTestServer(); - using SqlConnection connection = new SqlConnection(server.ConnectionString); + using TdsServer server = new TdsServer(new TdsServerArguments() { }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + using SqlConnection connection = new SqlConnection(connStr); connection.Open(); } - [ConditionalFact(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Fact] [PlatformSpecific(TestPlatforms.Windows)] public void IntegratedAuthConnectionTest() { - using TestTdsServer server = TestTdsServer.StartTestServer(); - SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(server.ConnectionString); + using TdsServer server = new TdsServer(new TdsServerArguments() { }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + SqlConnectionStringBuilder builder = new SqlConnectionStringBuilder(connStr); builder.IntegratedSecurity = true; using SqlConnection connection = new SqlConnection(builder.ConnectionString); connection.Open(); @@ -49,107 +58,242 @@ public void IntegratedAuthConnectionTest() /// when client enables encryption using Encrypt=true or uses default encryption setting. /// [Fact] - public async Task PreLoginEncryptionExcludedTest() + public async Task RequestEncryption_ServerDoesNotSupportEncryption_ShouldFail() { - using TestTdsServer server = TestTdsServer.StartTestServer(false, false, 5, excludeEncryption: true); - SqlConnectionStringBuilder builder = new(server.ConnectionString) - { - IntegratedSecurity = true - }; + using TdsServer server = new TdsServer(new TdsServerArguments() {Encryption = TDSPreLoginTokenEncryptionType.None }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}" + }.ConnectionString; - using SqlConnection connection = new(builder.ConnectionString); + using SqlConnection connection = new(connStr); Exception ex = await Assert.ThrowsAsync(async () => await connection.OpenAsync()); Assert.Contains("The instance of SQL Server you attempted to connect to does not support encryption.", ex.Message, StringComparison.OrdinalIgnoreCase); } - [ConditionalTheory(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Theory] [InlineData(40613)] [InlineData(42108)] [InlineData(42109)] - [PlatformSpecific(TestPlatforms.Windows)] - public async Task TransientFaultTestAsync(uint errorCode) + public async Task TransientFault_RetryEnabled_ShouldSucceed_Async(uint errorCode) { - using TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, false, errorCode); + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + server.Start(); SqlConnectionStringBuilder builder = new() { - DataSource = "localhost," + server.Port, - IntegratedSecurity = true, + DataSource = "localhost," + server.EndPoint.Port, Encrypt = SqlConnectionEncryptOption.Optional }; using SqlConnection connection = new(builder.ConnectionString); await connection.OpenAsync(); Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); } - [ConditionalTheory(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Theory] [InlineData(40613)] [InlineData(42108)] [InlineData(42109)] - [PlatformSpecific(TestPlatforms.Windows)] - public void TransientFaultTest(uint errorCode) + public void TransientFault_RetryEnabled_ShouldSucceed(uint errorCode) { - using TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, false, errorCode); + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + server.Start(); SqlConnectionStringBuilder builder = new() { - DataSource = "localhost," + server.Port, - IntegratedSecurity = true, + DataSource = "localhost," + server.EndPoint.Port, Encrypt = SqlConnectionEncryptOption.Optional }; using SqlConnection connection = new(builder.ConnectionString); - try - { - connection.Open(); - Assert.Equal(ConnectionState.Open, connection.State); - } - catch (Exception e) - { - Assert.Fail(e.Message); - } + connection.Open(); + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); } - [ConditionalTheory(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Theory] [InlineData(40613)] [InlineData(42108)] [InlineData(42109)] - [PlatformSpecific(TestPlatforms.Windows)] - public void TransientFaultDisabledTestAsync(uint errorCode) + public async Task TransientFault_RetryDisabled_ShouldFail_Async(uint errorCode) { - using TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, false, errorCode); + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + server.Start(); SqlConnectionStringBuilder builder = new() { - DataSource = "localhost," + server.Port, - IntegratedSecurity = true, + DataSource = "localhost," + server.EndPoint.Port, ConnectRetryCount = 0, Encrypt = SqlConnectionEncryptOption.Optional }; using SqlConnection connection = new(builder.ConnectionString); - Task e = Assert.ThrowsAsync(async () => await connection.OpenAsync()); - Assert.Equal(20, e.Result.Class); + SqlException e = await Assert.ThrowsAsync(async () => await connection.OpenAsync()); + Assert.Equal((int)errorCode, e.Number); Assert.Equal(ConnectionState.Closed, connection.State); } - [ConditionalTheory(typeof(TestUtility), nameof(TestUtility.IsNotArmProcess))] + [Theory] [InlineData(40613)] [InlineData(42108)] [InlineData(42109)] - [PlatformSpecific(TestPlatforms.Windows)] - public void TransientFaultDisabledTest(uint errorCode) + public void TransientFault_RetryDisabled_ShouldFail(uint errorCode) + { + using TransientFaultTdsServer server = new TransientFaultTdsServer( + new TransientFaultTdsServerArguments() + { + IsEnabledTransientError = true, + Number = errorCode, + }); + server.Start(); + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + ConnectRetryCount = 0, + Encrypt = SqlConnectionEncryptOption.Optional + }; + + using SqlConnection connection = new(builder.ConnectionString); + SqlException e = Assert.Throws(() => connection.Open()); + Assert.Equal((int)errorCode, e.Number); + Assert.Equal(ConnectionState.Closed, connection.State); + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public async Task NetworkError_RetryEnabled_ShouldSucceed_Async(bool multiSubnetFailoverEnabled) + { + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(1000), + }); + server.Start(); + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + Encrypt = SqlConnectionEncryptOption.Optional, + ConnectTimeout = 5, + MultiSubnetFailover = multiSubnetFailoverEnabled, +#if NETFRAMEWORK + TransparentNetworkIPResolution = multiSubnetFailoverEnabled +#endif + }; + + using SqlConnection connection = new(builder.ConnectionString); + await connection.OpenAsync(); + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + if (multiSubnetFailoverEnabled) + { + Assert.True(server.PreLoginCount > 1, "Expected multiple pre-login attempts due to retry."); + } + else + { + Assert.Equal(1, server.PreLoginCount); + } + } + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void NetworkError_RetryEnabled_ShouldSucceed(bool multiSubnetFailoverEnabled) { - using TransientFaultTDSServer server = TransientFaultTDSServer.StartTestServer(true, false, errorCode); + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(3000), + }); + server.Start(); + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + Encrypt = SqlConnectionEncryptOption.Optional, + ConnectTimeout = 5, + MultiSubnetFailover = multiSubnetFailoverEnabled, +#if NETFRAMEWORK + TransparentNetworkIPResolution = multiSubnetFailoverEnabled +#endif + }; + + using SqlConnection connection = new(builder.ConnectionString); + connection.Open(); + + Assert.Equal(ConnectionState.Open, connection.State); + Assert.Equal($"localhost,{server.EndPoint.Port}", connection.DataSource); + if (multiSubnetFailoverEnabled) + { + Assert.True(server.PreLoginCount > 1, "Expected multiple pre-login attempts due to retry."); + } + else + { + Assert.Equal(1, server.PreLoginCount); + } + } + + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3527")] + [Fact] + public async Task NetworkError_RetryDisabled_ShouldFail_Async() + { + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(1000), + }); + server.Start(); SqlConnectionStringBuilder builder = new() { - DataSource = "localhost," + server.Port, - IntegratedSecurity = true, + DataSource = "localhost," + server.EndPoint.Port, ConnectRetryCount = 0, Encrypt = SqlConnectionEncryptOption.Optional }; + using SqlConnection connection = new(builder.ConnectionString); + SqlException e = await Assert.ThrowsAsync(async () => await connection.OpenAsync()); + Assert.Contains("Connection Timeout Expired", e.Message); + Assert.Equal(ConnectionState.Closed, connection.State); + } + + [ActiveIssue("https://github.com/dotnet/SqlClient/issues/3527")] + [Fact] + public void NetworkError_RetryDisabled_ShouldFail() + { + using TransientDelayTdsServer server = new TransientDelayTdsServer( + new TransientDelayTdsServerArguments() + { + IsEnabledTransientTimeout = true, + SleepDuration = TimeSpan.FromMilliseconds(1000), + }); + server.Start(); + SqlConnectionStringBuilder builder = new() + { + DataSource = "localhost," + server.EndPoint.Port, + ConnectRetryCount = 0, + Encrypt = SqlConnectionEncryptOption.Optional, + ConnectTimeout = 5 + }; + using SqlConnection connection = new(builder.ConnectionString); SqlException e = Assert.Throws(() => connection.Open()); - Assert.Equal(20, e.Class); + Assert.Contains("Connection Timeout Expired", e.Message); Assert.Equal(ConnectionState.Closed, connection.State); } @@ -302,16 +446,20 @@ public void ConnectionTestValidCredentialCombination() [Theory] [InlineData(60)] - [InlineData(30)] - [InlineData(15)] [InlineData(10)] - [InlineData(5)] [InlineData(1)] public void ConnectionTimeoutTest(int timeout) { // Start a server with connection timeout from the inline data. - using TestTdsServer server = TestTdsServer.StartTestServer(false, false, timeout); - using SqlConnection connection = new SqlConnection(server.ConnectionString); + //TODO: do we even need a server for this test? + using TdsServer server = new TdsServer(); + server.Start(); + var connStr = new SqlConnectionStringBuilder() { + DataSource = $"localhost,{server.EndPoint.Port}", + ConnectTimeout = timeout, + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + using SqlConnection connection = new SqlConnection(connStr); // Dispose the server to force connection timeout server.Dispose(); @@ -341,16 +489,21 @@ public void ConnectionTimeoutTest(int timeout) [Theory] [InlineData(60)] - [InlineData(30)] - [InlineData(15)] [InlineData(10)] - [InlineData(5)] [InlineData(1)] public async Task ConnectionTimeoutTestAsync(int timeout) { // Start a server with connection timeout from the inline data. - using TestTdsServer server = TestTdsServer.StartTestServer(false, false, timeout); - using SqlConnection connection = new SqlConnection(server.ConnectionString); + //TODO: do we even need a server for this test? + using TdsServer server = new TdsServer(); + server.Start(); + var connStr = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + ConnectTimeout = timeout, + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + using SqlConnection connection = new SqlConnection(connStr); // Dispose the server to force connection timeout server.Dispose(); @@ -385,7 +538,11 @@ public void ConnectionInvalidTimeoutTest() { Assert.Throws(() => { - using TestTdsServer server = TestTdsServer.StartTestServer(false, false, -5); + var connectionString = new SqlConnectionStringBuilder() + { + DataSource = "localhost", + ConnectTimeout = -5 // Invalid timeout + }.ConnectionString; }); } @@ -401,8 +558,15 @@ public void ConnectionTestWithCultureTH() Thread.CurrentThread.CurrentCulture = new CultureInfo("th-TH"); Thread.CurrentThread.CurrentUICulture = new CultureInfo("th-TH"); - using TestTdsServer server = TestTdsServer.StartTestServer(); - using SqlConnection connection = new SqlConnection(server.ConnectionString); + //TODO: do we even need a server for this test? + using TdsServer server = new TdsServer(); + server.Start(); + var connStr = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional + }.ConnectionString; + using SqlConnection connection = new SqlConnection(connStr); connection.Open(); Assert.Equal(ConnectionState.Open, connection.State); } @@ -505,8 +669,19 @@ public void ConnectionTestAccessTokenCallbackCombinations() public void ConnectionTestPermittedVersion(int major, int minor, int build) { Version simulatedServerVersion = new Version(major, minor, build); - using TestTdsServer server = TestTdsServer.StartTestServer(serverVersion: simulatedServerVersion); - using SqlConnection conn = new SqlConnection(server.ConnectionString); + + using TdsServer server = new TdsServer( + new TdsServerArguments + { + ServerVersion = simulatedServerVersion, + }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + using SqlConnection conn = new SqlConnection(connStr); conn.Open(); Assert.Equal(ConnectionState.Open, conn.State); @@ -523,8 +698,18 @@ public void ConnectionTestPermittedVersion(int major, int minor, int build) public void ConnectionTestDeniedVersion(int major, int minor, int build) { Version simulatedServerVersion = new Version(major, minor, build); - using TestTdsServer server = TestTdsServer.StartTestServer(serverVersion: simulatedServerVersion); - using SqlConnection conn = new SqlConnection(server.ConnectionString); + using TdsServer server = new TdsServer( + new TdsServerArguments + { + ServerVersion = simulatedServerVersion, + }); + server.Start(); + var connStr = new SqlConnectionStringBuilder() + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + using SqlConnection conn = new SqlConnection(connStr); Assert.Throws(() => conn.Open()); } @@ -542,7 +727,8 @@ public void ConnectionTestDeniedVersion(int major, int minor, int build) public void TestConnWithVectorFeatExtVersionNegotiation(bool expectedConnectionResult, byte serverVersion, byte expectedNegotiatedVersion) { // Start the test TDS server. - using var server = TestTdsServer.StartTestServer(); + using var server = new TdsServer(); + server.Start(); server.ServerSupportedVectorFeatureExtVersion = serverVersion; server.EnableVectorFeatureExt = serverVersion == 0xFF ? false : true; @@ -594,7 +780,12 @@ public void TestConnWithVectorFeatExtVersionNegotiation(bool expectedConnectionR }; // Connect to the test TDS server. - using var connection = new SqlConnection(server.ConnectionString); + var connStr = new SqlConnectionStringBuilder + { + DataSource = $"localhost,{server.EndPoint.Port}", + Encrypt = SqlConnectionEncryptOption.Optional, + }.ConnectionString; + using var connection = new SqlConnection(connStr); if (expectedConnectionResult) { connection.Open(); diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPoint.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPoint.cs index e81139c63a..ac81691d40 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPoint.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPoint.cs @@ -30,7 +30,8 @@ public override TDSServerEndPointConnection CreateConnection(TcpClient newConnec /// /// General server handler /// - public abstract class ServerEndPointHandler where T : ServerEndPointConnection + public abstract class ServerEndPointHandler : IDisposable + where T : ServerEndPointConnection { /// /// Gets/Sets the event log for the proxy server @@ -131,25 +132,7 @@ public void Stop() // Request the listener thread to stop StopRequested = true; - // A copy of the list of connections to avoid locking - IList unlockedConnections = new List(); - - // Synchronize access to connections collection - lock (Connections) - { - // Iterate over all connections and copy into the local list - foreach (T connection in Connections) - { - unlockedConnections.Add(connection); - } - } - - // Iterate over all connections and request each one to stop - foreach (T connection in unlockedConnections) - { - // Request to stop - connection.Stop(); - } + KillAllConnections(); // If server failed to start there is no thread to join if (ListenerThread != null) @@ -167,6 +150,28 @@ public void Stop() } } + public void KillAllConnections() + { + // Synchronize access to connections collection + lock (Connections) + { + // Iterate over all connections and request each one to stop + foreach (T connection in Connections) + { + // Request to stop + connection.Dispose(); + } + // Clear the connections list + Connections.Clear(); + } + } + + public void Dispose() + { + // Stop the listener + Stop(); + } + /// /// Processes all incoming requests /// diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPointConnection.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPointConnection.cs index 6327189691..84219e430c 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPointConnection.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.EndPoint/TDSServerEndPointConnection.cs @@ -8,6 +8,7 @@ using System.Net; using System.Net.Sockets; using System.Threading; +using System.Threading.Tasks; namespace Microsoft.SqlServer.TDS.EndPoint { @@ -44,12 +45,12 @@ public override void ProcessData(Stream rawStream) /// /// Connection to a single client /// - public abstract class ServerEndPointConnection + public abstract class ServerEndPointConnection : IDisposable { /// /// Worker thread /// - protected Thread ProcessorThread { get; set; } + protected Task ProcessorTask { get; set; } /// /// Gets/Sets the event log for the proxy server @@ -77,9 +78,9 @@ public abstract class ServerEndPointConnection protected TcpClient Connection { get; set; } /// - /// The flag indicates whether server is being stopped + /// Cancellation token source for managing cancellation of the processing thread /// - protected bool StopRequested { get; set; } + private CancellationTokenSource CancellationTokenSource = new CancellationTokenSource(); /// /// Initialization constructor @@ -124,13 +125,8 @@ public ServerEndPointConnection(ITDSServer server, TcpClient connection) /// internal void Start() { - // Start with active connection - StopRequested = false; - // Prepare and start a thread - ProcessorThread = new Thread(new ThreadStart(_ConnectionHandler)) { IsBackground = true }; - ProcessorThread.Name = string.Format("TDS Server Connection {0} Thread", Connection.Client.RemoteEndPoint); - ProcessorThread.Start(); + ProcessorTask = RunConnectionHandler(CancellationTokenSource.Token); } /// @@ -138,15 +134,7 @@ internal void Start() /// internal void Stop() { - // Request the listener thread to stop - StopRequested = true; - - // If connection failed to start there's no processor thread - if (ProcessorThread != null) - { - // Wait for termination - ProcessorThread.Join(); - } + CancellationTokenSource.Cancel(); } /// @@ -159,10 +147,22 @@ internal void Stop() /// public abstract void ProcessData(Stream rawStream); + public void Dispose() + { + Stop(); + + if (Connection != null) + { + Connection.Dispose(); + } + + CancellationTokenSource.Dispose(); + } + /// /// Worker thread /// - private void _ConnectionHandler() + private async Task RunConnectionHandler(CancellationToken cancellationToken) { try { @@ -171,7 +171,7 @@ private void _ConnectionHandler() PrepareForProcessingData(rawStream); // Process the packet sequence - while (Connection.Connected && !StopRequested) + while (Connection.Connected && !cancellationToken.IsCancellationRequested) { // Check incoming buffer if (rawStream.DataAvailable) @@ -187,7 +187,7 @@ private void _ConnectionHandler() } // Sleep a bit to reduce load on CPU - Thread.Sleep(10); + await Task.Delay(10); } } } @@ -212,6 +212,8 @@ private void _ConnectionHandler() { OnConnectionClosed(this, null); } + + return; } /// diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServer.cs similarity index 96% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServer.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServer.cs index 06261e2c8f..2ed26f0beb 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServer.cs @@ -12,20 +12,20 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// TDS Server that authenticates clients according to the requested parameters /// - public class AuthenticatingTDSServer : GenericTDSServer + public class AuthenticatingTdsServer : GenericTdsServer { /// /// Initialization constructor /// - public AuthenticatingTDSServer() : - this(new AuthenticatingTDSServerArguments()) + public AuthenticatingTdsServer() : + this(new AuthenticatingTdsServerArguments()) { } /// /// Initialization constructor /// - public AuthenticatingTDSServer(AuthenticatingTDSServerArguments arguments) : + public AuthenticatingTdsServer(AuthenticatingTdsServerArguments arguments) : base(arguments) { } @@ -39,10 +39,10 @@ public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; // Check if arguments are of the authenticating TDS server - if (Arguments is AuthenticatingTDSServerArguments) + if (Arguments is AuthenticatingTdsServerArguments) { // Cast to authenticating TDS server arguments - AuthenticatingTDSServerArguments ServerArguments = Arguments as AuthenticatingTDSServerArguments; + AuthenticatingTdsServerArguments ServerArguments = Arguments as AuthenticatingTdsServerArguments; // Check if we're still processing normal login if (ServerArguments.ApplicationIntentFilter != ApplicationIntentFilterType.All) diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServerArguments.cs similarity index 54% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServerArguments.cs index dcb812a648..52a1764c28 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/AuthenticatingTdsServerArguments.cs @@ -7,43 +7,31 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Arguments for authenticating TDS Server /// - public class AuthenticatingTDSServerArguments : TDSServerArguments + public class AuthenticatingTdsServerArguments : TdsServerArguments { /// /// Type of the application intent filter /// - public ApplicationIntentFilterType ApplicationIntentFilter { get; set; } + public ApplicationIntentFilterType ApplicationIntentFilter = ApplicationIntentFilterType.All; /// /// Filter for server name /// - public string ServerNameFilter { get; set; } + public string ServerNameFilter = string.Empty; /// /// Type of the filtering algorithm to use /// - public ServerNameFilterType ServerNameFilterType { get; set; } + public ServerNameFilterType ServerNameFilterType = ServerNameFilterType.None; /// /// TDS packet size filtering /// - public ushort? PacketSizeFilter { get; set; } + public ushort? PacketSizeFilter = null; /// /// Filter for application name /// - public string ApplicationNameFilter { get; set; } - - /// - /// Initialization constructor - /// - public AuthenticatingTDSServerArguments() - { - // Allow everyone to connect - ApplicationIntentFilter = ApplicationIntentFilterType.All; - - // By default we don't turn on server name filter - ServerNameFilterType = Servers.ServerNameFilterType.None; - } + public string ApplicationNameFilter = string.Empty; } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSScenarioType.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsScenarioType.cs similarity index 94% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSScenarioType.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsScenarioType.cs index 11baa170d6..f35f69c22d 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSScenarioType.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsScenarioType.cs @@ -4,7 +4,7 @@ namespace Microsoft.SqlServer.TDS.Servers { - public enum FederatedAuthenticationNegativeTDSScenarioType : int + public enum FederatedAuthenticationNegativeTdsScenarioType : int { /// /// Valid Scenario. Do not perform negative activity. diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServer.cs similarity index 85% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServer.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServer.cs index 40d4791f13..4ea2bb1f21 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServer.cs @@ -12,20 +12,20 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// TDS Server that generates invalid TDS scenarios according to the requested parameters /// - public class FederatedAuthenticationNegativeTDSServer : GenericTDSServer + public class FederatedAuthenticationNegativeTdsServer : GenericTdsServer { /// /// Initialization constructor /// - public FederatedAuthenticationNegativeTDSServer() : - this(new FederatedAuthenticationNegativeTDSServerArguments()) + public FederatedAuthenticationNegativeTdsServer() : + this(new FederatedAuthenticationNegativeTdsServerArguments()) { } /// /// Initialization constructor /// - public FederatedAuthenticationNegativeTDSServer(FederatedAuthenticationNegativeTDSServerArguments arguments) : + public FederatedAuthenticationNegativeTdsServer(FederatedAuthenticationNegativeTdsServerArguments arguments) : base(arguments) { } @@ -39,10 +39,10 @@ public override TDSMessageCollection OnPreLoginRequest(ITDSServerSession session TDSMessageCollection preLoginCollection = base.OnPreLoginRequest(session, request); // Check if arguments are of the Federated Authentication server - if (Arguments is FederatedAuthenticationNegativeTDSServerArguments) + if (Arguments is FederatedAuthenticationNegativeTdsServerArguments) { // Cast to federated authentication server arguments - FederatedAuthenticationNegativeTDSServerArguments ServerArguments = Arguments as FederatedAuthenticationNegativeTDSServerArguments; + FederatedAuthenticationNegativeTdsServerArguments ServerArguments = Arguments as FederatedAuthenticationNegativeTdsServerArguments; // Find the is token carrying on TDSPreLoginToken TDSPreLoginToken preLoginToken = preLoginCollection.Find(message => message.Exists(packetToken => packetToken is TDSPreLoginToken)). @@ -50,7 +50,7 @@ public override TDSMessageCollection OnPreLoginRequest(ITDSServerSession session switch (ServerArguments.Scenario) { - case FederatedAuthenticationNegativeTDSScenarioType.NonceMissingInFedAuthPreLogin: + case FederatedAuthenticationNegativeTdsScenarioType.NonceMissingInFedAuthPreLogin: { // If we have the prelogin token if (preLoginToken != null && preLoginToken.Nonce != null) @@ -62,7 +62,7 @@ public override TDSMessageCollection OnPreLoginRequest(ITDSServerSession session break; } - case FederatedAuthenticationNegativeTDSScenarioType.InvalidB_FEDAUTHREQUIREDResponse: + case FederatedAuthenticationNegativeTdsScenarioType.InvalidB_FEDAUTHREQUIREDResponse: { // If we have the prelogin token if (preLoginToken != null) @@ -89,10 +89,10 @@ public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessageCollection login7Collection = base.OnLogin7Request(session, request); // Check if arguments are of the Federated Authentication server - if (Arguments is FederatedAuthenticationNegativeTDSServerArguments) + if (Arguments is FederatedAuthenticationNegativeTdsServerArguments) { // Cast to federated authentication server arguments - FederatedAuthenticationNegativeTDSServerArguments ServerArguments = Arguments as FederatedAuthenticationNegativeTDSServerArguments; + FederatedAuthenticationNegativeTdsServerArguments ServerArguments = Arguments as FederatedAuthenticationNegativeTdsServerArguments; // Get the Federated Authentication ExtAck from Login 7 TDSFeatureExtAckFederatedAuthenticationOption fedAutExtAct = GetFeatureExtAckFederatedAuthenticationOptionFromLogin7(login7Collection); @@ -102,21 +102,21 @@ public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, { switch (ServerArguments.Scenario) { - case FederatedAuthenticationNegativeTDSScenarioType.NonceMissingInFedAuthFEATUREXTACK: + case FederatedAuthenticationNegativeTdsScenarioType.NonceMissingInFedAuthFEATUREXTACK: { // Delete the nonce from the Token fedAutExtAct.ClientNonce = null; break; } - case FederatedAuthenticationNegativeTDSScenarioType.FedAuthMissingInFEATUREEXTACK: + case FederatedAuthenticationNegativeTdsScenarioType.FedAuthMissingInFEATUREEXTACK: { // Remove the Fed Auth Ext Ack from the options list in the FeatureExtAckToken GetFeatureExtAckTokenFromLogin7(login7Collection).Options.Remove(fedAutExtAct); break; } - case FederatedAuthenticationNegativeTDSScenarioType.SignatureMissingInFedAuthFEATUREXTACK: + case FederatedAuthenticationNegativeTdsScenarioType.SignatureMissingInFedAuthFEATUREXTACK: { // Delete the signature from the Token fedAutExtAct.Signature = null; diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServerArguments.cs similarity index 56% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServerArguments.cs index 67143d645b..34696c4f2e 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/FederatedAuthenticationNegativeTdsServerArguments.cs @@ -7,18 +7,11 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Arguments for Fed Auth Negative TDS Server /// - public class FederatedAuthenticationNegativeTDSServerArguments : TDSServerArguments + public class FederatedAuthenticationNegativeTdsServerArguments : TdsServerArguments { /// /// Type of the Fed Auth Negative TDS Server /// - public FederatedAuthenticationNegativeTDSScenarioType Scenario { get; set; } - - /// - /// Initialization constructor - /// - public FederatedAuthenticationNegativeTDSServerArguments() - { - } + public FederatedAuthenticationNegativeTdsScenarioType Scenario = FederatedAuthenticationNegativeTdsScenarioType.ValidScenario; } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs similarity index 92% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs index ac04fd2f57..6e6aa98ce1 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServer.cs @@ -4,6 +4,8 @@ using System; using System.Linq; +using System.Net; +using System.Runtime.CompilerServices; using System.Security.Cryptography; using System.Threading; using Microsoft.SqlServer.TDS.Authentication; @@ -25,7 +27,8 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Generic TDS server without specialization /// - public class GenericTDSServer : ITDSServer + public abstract class GenericTdsServer : ITDSServer, IDisposable + where T : TdsServerArguments { /// /// Delegate to be called when a LOGIN7 request has been received and is @@ -69,10 +72,19 @@ public delegate void OnAuthenticationCompletedDelegate( /// private int _sessionCount = 0; + /// + /// Counts pre-login requests to the server. + /// + private int _preLoginCount = 0; + + private TDSServerEndPoint _endpoint; + + public IPEndPoint EndPoint => _endpoint.ServerEndPoint; + /// /// Server configuration /// - protected TDSServerArguments Arguments { get; set; } + protected T Arguments { get; set; } /// /// Query engine instance @@ -80,17 +92,14 @@ public delegate void OnAuthenticationCompletedDelegate( protected QueryEngine Engine { get; set; } /// - /// Default constructor + /// Counts pre-login requests to the server. /// - public GenericTDSServer() : - this(new TDSServerArguments()) - { - } + public int PreLoginCount => _preLoginCount; /// /// Initialization constructor /// - public GenericTDSServer(TDSServerArguments arguments) : + public GenericTdsServer(T arguments) : this(arguments, new QueryEngine(arguments)) { } @@ -98,7 +107,7 @@ public GenericTDSServer(TDSServerArguments arguments) : /// /// Initialization constructor /// - public GenericTDSServer(TDSServerArguments arguments, QueryEngine queryEngine) + public GenericTdsServer(T arguments, QueryEngine queryEngine) { // Save arguments Arguments = arguments; @@ -110,6 +119,19 @@ public GenericTDSServer(TDSServerArguments arguments, QueryEngine queryEngine) Engine.Log = Arguments.Log; } + public void Start([CallerMemberName] string methodName = "") + { + _endpoint = new TDSServerEndPoint(this) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; + _endpoint.EndpointName = methodName; + _endpoint.EventLog = Arguments.Log; + _endpoint.Start(); + } + + public void KillAllConnections() + { + _endpoint.KillAllConnections(); + } + /// /// Create a new session on the server /// @@ -120,7 +142,7 @@ public virtual ITDSServerSession OpenSession() Interlocked.Increment(ref _sessionCount); // Create a new session - GenericTDSServerSession session = new GenericTDSServerSession(this, (uint)_sessionCount); + GenericTdsServerSession session = new GenericTdsServerSession(this, (uint)_sessionCount); // Use configured encryption certificate and protocols session.EncryptionCertificate = Arguments.EncryptionCertificate; @@ -142,8 +164,11 @@ public virtual void CloseSession(ITDSServerSession session) /// public virtual TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, TDSMessage request) { + Interlocked.Increment(ref _preLoginCount); + // Inflate pre-login request from the message TDSPreLoginToken preLoginRequest = request[0] as TDSPreLoginToken; + GenericTdsServerSession genericTdsServerSession = session as GenericTdsServerSession; // Log request TDSUtilities.Log(Arguments.Log, "Request", preLoginRequest); @@ -158,7 +183,7 @@ public virtual TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, TDSPreLoginToken preLoginToken = new TDSPreLoginToken(Arguments.ServerVersion, serverResponse, false); // TDS server doesn't support MARS // Cache the received Nonce into the session - (session as GenericTDSServerSession).ClientNonce = preLoginRequest.Nonce; + genericTdsServerSession.ClientNonce = preLoginRequest.Nonce; // Check if the server has been started up as requiring FedAuth when choosing between SSPI and FedAuth if (Arguments.FedAuthRequiredPreLoginOption == TdsPreLoginFedAuthRequiredOption.FedAuthRequired) @@ -170,7 +195,7 @@ public virtual TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, } // Keep the federated authentication required flag in the server session - (session as GenericTDSServerSession).FedAuthRequiredPreLoginServerResponse = preLoginToken.FedAuthRequired; + genericTdsServerSession.FedAuthRequiredPreLoginServerResponse = preLoginToken.FedAuthRequired; if (preLoginRequest.Nonce != null) { @@ -180,7 +205,7 @@ public virtual TDSMessageCollection OnPreLoginRequest(ITDSServerSession session, } // Cache the server Nonce in a session - (session as GenericTDSServerSession).ServerNonce = preLoginToken.Nonce; + genericTdsServerSession.ServerNonce = preLoginToken.Nonce; // Log response TDSUtilities.Log(Arguments.Log, "Response", preLoginToken); @@ -244,7 +269,7 @@ public virtual TDSMessageCollection OnLogin7Request(ITDSServerSession session, T TDSLogin7SessionRecoveryOptionToken sessionStateOption = option as TDSLogin7SessionRecoveryOptionToken; // Inflate session state - (session as GenericTDSServerSession).Inflate(sessionStateOption.Initial, sessionStateOption.Current); + (session as GenericTdsServerSession).Inflate(sessionStateOption.Initial, sessionStateOption.Current); break; } @@ -266,7 +291,7 @@ public virtual TDSMessageCollection OnLogin7Request(ITDSServerSession session, T } // Save the fed auth library to be used - (session as GenericTDSServerSession).FederatedAuthenticationLibrary = federatedAuthenticationOption.Library; + (session as GenericTdsServerSession).FederatedAuthenticationLibrary = federatedAuthenticationOption.Library; break; } @@ -542,7 +567,7 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi responseMessage.Add(infoToken); // Create new collation change token - envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.SQLCollation, (session as GenericTDSServerSession).Collation); + envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.SQLCollation, (session as GenericTdsServerSession).Collation); // Log response TDSUtilities.Log(Arguments.Log, "Response", envChange); @@ -551,7 +576,7 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi responseMessage.Add(envChange); // Create new language change token - envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.Language, LanguageString.ToString((session as GenericTDSServerSession).Language)); + envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.Language, LanguageString.ToString((session as GenericTdsServerSession).Language)); // Log response TDSUtilities.Log(Arguments.Log, "Response", envChange); @@ -593,7 +618,7 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi if (session.IsSessionRecoveryEnabled) { // Create Feature extension Ack token - TDSFeatureExtAckToken featureExtActToken = new TDSFeatureExtAckToken(new TDSFeatureExtAckSessionStateOption((session as GenericTDSServerSession).Deflate())); + TDSFeatureExtAckToken featureExtActToken = new TDSFeatureExtAckToken(new TDSFeatureExtAckSessionStateOption((session as GenericTdsServerSession).Deflate())); // Log response TDSUtilities.Log(Arguments.Log, "Response", featureExtActToken); @@ -654,6 +679,16 @@ protected virtual TDSMessageCollection OnAuthenticationCompleted(ITDSServerSessi } } + if (!String.IsNullOrEmpty(Arguments.FailoverPartner)) + { + envChange = new TDSEnvChangeToken(TDSEnvChangeTokenType.RealTimeLogShipping, Arguments.FailoverPartner); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", envChange); + + responseMessage.Add(envChange); + } + // Create DONE token TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final); @@ -688,7 +723,7 @@ protected virtual TDSMessageCollection OnFederatedAuthenticationCompleted(ITDSSe try { // Get the Federated Authentication ticket using RPS - decryptedTicket = FederatedAuthenticationTicketService.DecryptTicket((session as GenericTDSServerSession).FederatedAuthenticationLibrary, ticket); + decryptedTicket = FederatedAuthenticationTicketService.DecryptTicket((session as GenericTdsServerSession).FederatedAuthenticationLibrary, ticket); if (decryptedTicket is RpsTicket) { @@ -719,17 +754,17 @@ protected virtual TDSMessageCollection OnFederatedAuthenticationCompleted(ITDSSe // Create federated authentication extension option TDSFeatureExtAckFederatedAuthenticationOption federatedAuthenticationOption; - if ((session as GenericTDSServerSession).FederatedAuthenticationLibrary == TDSFedAuthLibraryType.MSAL) + if ((session as GenericTdsServerSession).FederatedAuthenticationLibrary == TDSFedAuthLibraryType.MSAL) { // For the time being, fake fedauth tokens are used for ADAL, so decryptedTicket is null. federatedAuthenticationOption = - new TDSFeatureExtAckFederatedAuthenticationOption((session as GenericTDSServerSession).ClientNonce, null); + new TDSFeatureExtAckFederatedAuthenticationOption((session as GenericTdsServerSession).ClientNonce, null); } else { federatedAuthenticationOption = - new TDSFeatureExtAckFederatedAuthenticationOption((session as GenericTDSServerSession).ClientNonce, - decryptedTicket.GetSignature((session as GenericTDSServerSession).ClientNonce)); + new TDSFeatureExtAckFederatedAuthenticationOption((session as GenericTdsServerSession).ClientNonce, + decryptedTicket.GetSignature((session as GenericTdsServerSession).ClientNonce)); } // Look for feature extension token @@ -764,12 +799,12 @@ protected virtual TDSMessageCollection OnFederatedAuthenticationCompleted(ITDSSe protected virtual TDSMessageCollection CheckFederatedAuthenticationOption(ITDSServerSession session, TDSLogin7FedAuthOptionToken federatedAuthenticationOption) { // Check if server's prelogin response for FedAuthRequired prelogin option is echoed back correctly in FedAuth Feature Extenion Echo - if (federatedAuthenticationOption.Echo != (session as GenericTDSServerSession).FedAuthRequiredPreLoginServerResponse) + if (federatedAuthenticationOption.Echo != (session as GenericTdsServerSession).FedAuthRequiredPreLoginServerResponse) { // Create Error message string message = string.Format("FEDAUTHREQUIRED option in the prelogin response is not echoed back correctly: in prelogin response, it is {0} and in login, it is {1}: ", - (session as GenericTDSServerSession).FedAuthRequiredPreLoginServerResponse, + (session as GenericTdsServerSession).FedAuthRequiredPreLoginServerResponse, federatedAuthenticationOption.Echo); // Create errorToken token @@ -790,7 +825,7 @@ protected virtual TDSMessageCollection CheckFederatedAuthenticationOption(ITDSSe // Check if the nonce exists if ((federatedAuthenticationOption.Nonce == null && federatedAuthenticationOption.Library == TDSFedAuthLibraryType.IDCRL) - || !AreEqual((session as GenericTDSServerSession).ServerNonce, federatedAuthenticationOption.Nonce)) + || !AreEqual((session as GenericTdsServerSession).ServerNonce, federatedAuthenticationOption.Nonce)) { // Error message string message = string.Format("Unexpected NONCEOPT specified in the Federated authentication feature extension"); @@ -880,5 +915,7 @@ private bool AreEqual(byte[] left, byte[] right) return left.SequenceEqual(right); } + + public virtual void Dispose() => _endpoint?.Dispose(); } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs similarity index 99% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs index e9e65d5f8f..2730fa02df 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTDSServerSession.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/GenericTdsServerSession.cs @@ -17,7 +17,7 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Generic session for TDS Server /// - public class GenericTDSServerSession : ITDSServerSession + public class GenericTdsServerSession : ITDSServerSession { /// /// Server that created the session @@ -259,7 +259,7 @@ public bool AnsiDefaults /// /// Initialization constructor /// - public GenericTDSServerSession(ITDSServer server, uint sessionID) : + public GenericTdsServerSession(ITDSServer server, uint sessionID) : this(server, sessionID, 4096) { } @@ -267,7 +267,7 @@ public GenericTDSServerSession(ITDSServer server, uint sessionID) : /// /// Initialization constructor /// - public GenericTDSServerSession(ITDSServer server, uint sessionID, uint packetSize) + public GenericTdsServerSession(ITDSServer server, uint sessionID, uint packetSize) { // Save the server Server = server; diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/QueryEngine.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/QueryEngine.cs index eb219f5dbc..579c47abcc 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/QueryEngine.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/QueryEngine.cs @@ -26,12 +26,12 @@ public class QueryEngine /// /// Server configuration /// - public TDSServerArguments ServerArguments { get; private set; } + public TdsServerArguments ServerArguments { get; private set; } /// /// Initialization constructor /// - public QueryEngine(TDSServerArguments arguments) + public QueryEngine(TdsServerArguments arguments) { ServerArguments = arguments; } @@ -1308,7 +1308,7 @@ private TDSMessage _PrepareAnsiDefaultsResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiDefaults); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiDefaults); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1347,7 +1347,7 @@ private TDSMessage _PrepareAnsiNullDefaultOnResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiNullDefaultOn); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiNullDefaultOn); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1386,7 +1386,7 @@ private TDSMessage _PrepareAnsiNullsResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiNulls); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiNulls); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1425,7 +1425,7 @@ private TDSMessage _PrepareAnsiPaddingResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiPadding); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiPadding); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1464,7 +1464,7 @@ private TDSMessage _PrepareAnsiWarningsResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).AnsiWarnings); + rowToken.Data.Add((session as GenericTdsServerSession).AnsiWarnings); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1503,7 +1503,7 @@ private TDSMessage _PrepareArithAbortResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).ArithAbort); + rowToken.Data.Add((session as GenericTdsServerSession).ArithAbort); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1542,7 +1542,7 @@ private TDSMessage _PrepareConcatNullYieldsNullResponse(ITDSServerSession sessio TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).ConcatNullYieldsNull); + rowToken.Data.Add((session as GenericTdsServerSession).ConcatNullYieldsNull); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1581,7 +1581,7 @@ private TDSMessage _PrepareDateFirstResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((short)(session as GenericTDSServerSession).DateFirst); + rowToken.Data.Add((short)(session as GenericTdsServerSession).DateFirst); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1622,7 +1622,7 @@ private TDSMessage _PrepareDateFormatResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Generate a date format string - rowToken.Data.Add(DateFormatString.ToString((session as GenericTDSServerSession).DateFormat)); + rowToken.Data.Add(DateFormatString.ToString((session as GenericTdsServerSession).DateFormat)); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1661,7 +1661,7 @@ private TDSMessage _PrepareDeadlockPriorityResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Serialize the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).DeadlockPriority); + rowToken.Data.Add((session as GenericTdsServerSession).DeadlockPriority); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1702,7 +1702,7 @@ private TDSMessage _PrepareLanguageResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Generate a date format string - rowToken.Data.Add(LanguageString.ToString((session as GenericTDSServerSession).Language)); + rowToken.Data.Add(LanguageString.ToString((session as GenericTdsServerSession).Language)); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1741,7 +1741,7 @@ private TDSMessage _PrepareLockTimeoutResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Serialize the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).LockTimeout); + rowToken.Data.Add((session as GenericTdsServerSession).LockTimeout); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1780,7 +1780,7 @@ private TDSMessage _PrepareQuotedIdentifierResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).QuotedIdentifier); + rowToken.Data.Add((session as GenericTdsServerSession).QuotedIdentifier); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1819,7 +1819,7 @@ private TDSMessage _PrepareTextSizeResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((session as GenericTDSServerSession).TextSize); + rowToken.Data.Add((session as GenericTdsServerSession).TextSize); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1858,7 +1858,7 @@ private TDSMessage _PrepareTransactionIsolationLevelResponse(ITDSServerSession s TDSRowToken rowToken = new TDSRowToken(metadataToken); // Read the value from the session - rowToken.Data.Add((short)(session as GenericTDSServerSession).TransactionIsolationLevel); + rowToken.Data.Add((short)(session as GenericTdsServerSession).TransactionIsolationLevel); // Log response TDSUtilities.Log(Log, "Response", rowToken); @@ -1897,7 +1897,7 @@ private TDSMessage _PrepareOptionsResponse(ITDSServerSession session) TDSRowToken rowToken = new TDSRowToken(metadataToken); // Convert to generic session - GenericTDSServerSession genericSession = session as GenericTDSServerSession; + GenericTdsServerSession genericSession = session as GenericTdsServerSession; // Serialize the options into the bit mask int options = 0; @@ -2029,13 +2029,13 @@ private TDSMessage _PrepareContextInfoResponse(ITDSServerSession session) byte[] contextInfo = null; // Check if session has a context info - if ((session as GenericTDSServerSession).ContextInfo != null) + if ((session as GenericTdsServerSession).ContextInfo != null) { // Allocate a container contextInfo = new byte[128]; // Copy context info into the container - Array.Copy((session as GenericTDSServerSession).ContextInfo, contextInfo, (session as GenericTDSServerSession).ContextInfo.Length); + Array.Copy((session as GenericTdsServerSession).ContextInfo, contextInfo, (session as GenericTdsServerSession).ContextInfo.Length); } // Set context info diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs similarity index 91% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServer.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs index 57596b24ac..8e119a54cd 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServer.cs @@ -16,20 +16,20 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// TDS Server that routes clients to the configured destination /// - public class RoutingTDSServer : GenericTDSServer + public class RoutingTdsServer : GenericTdsServer { /// /// Initialization constructor /// - public RoutingTDSServer() : - this(new RoutingTDSServerArguments()) + public RoutingTdsServer() : + this(new RoutingTdsServerArguments()) { } /// /// Initialization constructor /// - public RoutingTDSServer(RoutingTDSServerArguments arguments) : + public RoutingTdsServer(RoutingTdsServerArguments arguments) : base(arguments) { } @@ -43,10 +43,10 @@ public override TDSMessageCollection OnPreLoginRequest(ITDSServerSession session TDSMessageCollection response = base.OnPreLoginRequest(session, request); // Check if arguments are of the routing server - if (Arguments is RoutingTDSServerArguments) + if (Arguments is RoutingTdsServerArguments) { // Cast to routing server arguments - RoutingTDSServerArguments serverArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments serverArguments = Arguments as RoutingTdsServerArguments; // Check if routing is configured during login if (serverArguments.RouteOnPacket == TDSMessageType.TDS7Login) @@ -78,10 +78,10 @@ public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; // Check if arguments are of the routing server - if (Arguments is RoutingTDSServerArguments) + if (Arguments is RoutingTdsServerArguments) { // Cast to routing server arguments - RoutingTDSServerArguments ServerArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; // Check filter if (ServerArguments.RequireReadOnly && (loginRequest.TypeFlags.ReadOnlyIntent != TDSLogin7TypeFlagsReadOnlyIntent.ReadOnly)) @@ -136,10 +136,10 @@ public override TDSMessageCollection OnSQLBatchRequest(ITDSServerSession session TDSMessageCollection batchResponse = base.OnSQLBatchRequest(session, request); // Check if arguments are of routing server - if (Arguments is RoutingTDSServerArguments) + if (Arguments is RoutingTdsServerArguments) { // Cast to routing server arguments - RoutingTDSServerArguments ServerArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; // Check routing condition if (ServerArguments.RouteOnPacket == TDSMessageType.SQLBatch) @@ -188,10 +188,10 @@ protected override TDSMessageCollection OnAuthenticationCompleted(ITDSServerSess TDSMessageCollection responseMessageCollection = base.OnAuthenticationCompleted(session); // Check if arguments are of routing server - if (Arguments is RoutingTDSServerArguments) + if (Arguments is RoutingTdsServerArguments) { // Cast to routing server arguments - RoutingTDSServerArguments serverArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments serverArguments = Arguments as RoutingTdsServerArguments; // Check routing condition if (serverArguments.RouteOnPacket == TDSMessageType.TDS7Login) @@ -233,7 +233,7 @@ protected override TDSMessageCollection OnAuthenticationCompleted(ITDSServerSess protected TDSPacketToken CreateRoutingToken() { // Cast to routing server arguments - RoutingTDSServerArguments ServerArguments = Arguments as RoutingTDSServerArguments; + RoutingTdsServerArguments ServerArguments = Arguments as RoutingTdsServerArguments; // Construct routing token value TDSRoutingEnvChangeTokenValue routingInfo = new TDSRoutingEnvChangeTokenValue(); diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs similarity index 51% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs index 99cbd3baae..056bd33b79 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/RoutingTdsServerArguments.cs @@ -7,43 +7,31 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Arguments for routing TDS Server /// - public class RoutingTDSServerArguments : TDSServerArguments + public class RoutingTdsServerArguments : TdsServerArguments { /// - /// Routing destination protocol + /// Routing destination protocol. /// - public int RoutingProtocol { get; set; } + public int RoutingProtocol = 0; /// /// Routing TCP port /// - public ushort RoutingTCPPort { get; set; } + public ushort RoutingTCPPort = 0; /// /// Routing TCP host name /// - public string RoutingTCPHost { get; set; } + public string RoutingTCPHost = string.Empty; /// /// Packet on which routing should occur /// - public TDSMessageType RouteOnPacket { get; set; } + public TDSMessageType RouteOnPacket = TDSMessageType.TDS7Login; /// /// Indicates that routing should only occur on read-only connections /// - public bool RequireReadOnly { get; set; } - - /// - /// Initialization constructor - /// - public RoutingTDSServerArguments() - { - // By default we route on login - RouteOnPacket = TDSMessageType.TDS7Login; - - // By default we reject non-read-only connections - RequireReadOnly = true; - } + public bool RequireReadOnly = true; } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDS.Servers.csproj b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDS.Servers.csproj index b7757b257b..6dc40e4d1b 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDS.Servers.csproj +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDS.Servers.csproj @@ -11,21 +11,24 @@ - - + + - - - - - + + + + + - - + + - - - + + + + + + diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServer.cs new file mode 100644 index 0000000000..fef5b851c8 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServer.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.SqlServer.TDS.Servers +{ + public class TdsServer : GenericTdsServer + { + /// + /// Default constructor + /// + public TdsServer() : this(new TdsServerArguments()) + { + } + /// + /// Constructor with arguments + /// + public TdsServer(TdsServerArguments arguments) : base(arguments) + { + } + + /// + /// Constructor with arguments and query engine + /// + /// Query engine + /// Server arguments + public TdsServer(QueryEngine queryEngine, TdsServerArguments arguments) : base(arguments, queryEngine) + { + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs similarity index 59% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs index 88e577ab68..51adc129c1 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TdsServerArguments.cs @@ -13,7 +13,7 @@ namespace Microsoft.SqlServer.TDS.Servers /// /// Common arguments for TDS Server /// - public class TDSServerArguments + public class TdsServerArguments { /// /// Service Principal Name, representing Azure SQL Database in Azure Active Directory. @@ -28,76 +28,56 @@ public class TDSServerArguments /// /// Log to which send TDS conversation /// - public TextWriter Log { get; set; } + public TextWriter Log = null; /// /// Server name /// - public string ServerName { get; set; } + public string ServerName = Environment.MachineName; /// /// Server version /// - public Version ServerVersion { get; set; } + public Version ServerVersion = new Version(11, 0, 1083); /// /// Server principal name /// - public string ServerPrincipalName { get; set; } + public string ServerPrincipalName = AzureADServicePrincipalName; /// /// Sts Url /// - public string StsUrl { get; set; } + public string StsUrl = AzureADProductionTokenEndpoint; /// /// Size of the TDS packet server should operate with /// - public int PacketSize { get; set; } + public int PacketSize = 4096; /// /// Transport encryption /// - public TDSPreLoginTokenEncryptionType Encryption { get; set; } + public TDSPreLoginTokenEncryptionType Encryption = TDSPreLoginTokenEncryptionType.NotSupported; /// /// Specifies the FedAuthRequired option /// - public TdsPreLoginFedAuthRequiredOption FedAuthRequiredPreLoginOption { get; set; } + public TdsPreLoginFedAuthRequiredOption FedAuthRequiredPreLoginOption = TdsPreLoginFedAuthRequiredOption.FedAuthNotRequired; /// /// Certificate to use for transport encryption /// - public X509Certificate EncryptionCertificate { get; set; } + public X509Certificate EncryptionCertificate = null; /// /// SSL/TLS protocols to use for transport encryption /// - public SslProtocols EncryptionProtocols { get; set; } + public SslProtocols EncryptionProtocols = SslProtocols.Tls12; /// - /// Initialization constructor + /// Routing destination protocol /// - public TDSServerArguments() - { - // Assign default server version - ServerName = Environment.MachineName; - ServerVersion = new Version(11, 0, 1083); - - // Default packet size - PacketSize = 4096; - - // By default we don't support encryption - Encryption = TDSPreLoginTokenEncryptionType.NotSupported; - - // By Default SQL authentication will be used. - FedAuthRequiredPreLoginOption = TdsPreLoginFedAuthRequiredOption.FedAuthNotRequired; - - EncryptionCertificate = null; - EncryptionProtocols = SslProtocols.Tls12; - - ServerPrincipalName = AzureADServicePrincipalName; - StsUrl = AzureADProductionTokenEndpoint; - } + public string FailoverPartner = string.Empty; } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServer.cs new file mode 100644 index 0000000000..fb4328700f --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServer.cs @@ -0,0 +1,82 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Threading; +using Microsoft.SqlServer.TDS.EndPoint; + +namespace Microsoft.SqlServer.TDS.Servers +{ + /// + /// TDS Server that authenticates clients according to the requested parameters + /// + public class TransientDelayTdsServer : GenericTdsServer, IDisposable + { + private static int RequestCounter = 0; + + public TransientDelayTdsServer(TransientDelayTdsServerArguments arguments) : base(arguments) + { + } + + public TransientDelayTdsServer(TransientDelayTdsServerArguments arguments, QueryEngine queryEngine) : base(arguments, queryEngine) + { + } + + public void ResetRequestCounter() + { + RequestCounter = 0; + } + + public void SetTransientTimeoutBehavior(bool isEnabledTransientTimeout, TimeSpan sleepDuration) + { + SetTransientTimeoutBehavior(isEnabledTransientTimeout, false, sleepDuration); + } + + public void SetTransientTimeoutBehavior(bool isEnabledTransientTimeout, bool isEnabledPermanentTimeout, TimeSpan sleepDuration) + { + Arguments.IsEnabledTransientTimeout = isEnabledTransientTimeout; + Arguments.IsEnabledPermanentTimeout = isEnabledPermanentTimeout; + Arguments.SleepDuration = sleepDuration; + } + + /// + /// Handler for login request + /// + public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) + { + // Check if we're still going to raise transient error + if (Arguments.IsEnabledPermanentTimeout || + (Arguments.IsEnabledTransientTimeout && RequestCounter < Arguments.RepeatCount)) + { + Thread.Sleep(Arguments.SleepDuration); + + RequestCounter++; + } + + // Return login response from the base class + return base.OnLogin7Request(session, request); + } + + /// + public override TDSMessageCollection OnSQLBatchRequest(ITDSServerSession session, TDSMessage message) + { + if (Arguments.IsEnabledPermanentTimeout || + (Arguments.IsEnabledTransientTimeout && RequestCounter < 1)) + { + Thread.Sleep(Arguments.SleepDuration); + + RequestCounter++; + } + + return base.OnSQLBatchRequest(session, message); + } + + /// + public override void Dispose() + { + base.Dispose(); + RequestCounter = 0; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServerArguments.cs new file mode 100644 index 0000000000..020fbd973d --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientDelayTdsServerArguments.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.SqlServer.TDS.Servers +{ + public class TransientDelayTdsServerArguments : TdsServerArguments + { + /// + /// The duration for which the server should sleep before responding to a request. + /// + public TimeSpan SleepDuration = TimeSpan.FromSeconds(0); + + /// + /// Flag to consider when simulating a timeout on the next request. + /// + public bool IsEnabledTransientTimeout = false; + + /// + /// Flag to consider when simulating a timeout on each request. + /// + public bool IsEnabledPermanentTimeout = false; + + /// + /// The number of times the transient error should be raised. + /// + public int RepeatCount = 1; + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServer.cs deleted file mode 100644 index 1933444df6..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServer.cs +++ /dev/null @@ -1,153 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Net; -using System.Runtime.CompilerServices; -using System.Threading; -using Microsoft.SqlServer.TDS.Done; -using Microsoft.SqlServer.TDS.EndPoint; -using Microsoft.SqlServer.TDS.Error; -using Microsoft.SqlServer.TDS.Login7; - -namespace Microsoft.SqlServer.TDS.Servers -{ - /// - /// TDS Server that authenticates clients according to the requested parameters - /// - public class TransientFaultTDSServer : GenericTDSServer, IDisposable - { - private static int RequestCounter = 0; - - public int Port { get; set; } - - /// - /// Constructor - /// - public TransientFaultTDSServer() => new TransientFaultTDSServer(new TransientFaultTDSServerArguments()); - - /// - /// Constructor - /// - /// - public TransientFaultTDSServer(TransientFaultTDSServerArguments arguments) : - base(arguments) - { } - - /// - /// Constructor - /// - /// - /// - public TransientFaultTDSServer(QueryEngine engine, TransientFaultTDSServerArguments args) : base(args) - { - Engine = engine; - } - - private TDSServerEndPoint _endpoint = null; - - private static string GetErrorMessage(uint errorNumber) - { - switch (errorNumber) - { - case 40613: - return "Database on server is not currently available. Please retry the connection later. " + - "If the problem persists, contact customer support, and provide them the session tracing ID."; - case 42108: - return "Can not connect to the SQL pool since it is paused. Please resume the SQL pool and try again."; - case 42109: - return "The SQL pool is warming up. Please try again."; - } - return "Unknown server error occurred"; - } - - /// - /// Handler for login request - /// - public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) - { - // Inflate login7 request from the message - TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; - - // Check if arguments are of the transient fault TDS server - if (Arguments is TransientFaultTDSServerArguments) - { - // Cast to transient fault TDS server arguments - TransientFaultTDSServerArguments ServerArguments = Arguments as TransientFaultTDSServerArguments; - - // Check if we're still going to raise transient error - if (ServerArguments.IsEnabledTransientError && RequestCounter < 1) // Fail first time, then connect - { - uint errorNumber = ServerArguments.Number; - string errorMessage = ServerArguments.Message; - - // Log request to which we're about to send a failure - TDSUtilities.Log(Arguments.Log, "Request", loginRequest); - - // Prepare ERROR token with the denial details - TDSErrorToken errorToken = new TDSErrorToken(errorNumber, 1, 20, errorMessage); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); - - // Serialize the error token into the response packet - TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); - - // Create DONE token - TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); - - // Log response - TDSUtilities.Log(Arguments.Log, "Response", doneToken); - - // Serialize DONE token into the response packet - responseMessage.Add(doneToken); - - RequestCounter++; - - // Put a single message into the collection and return it - return new TDSMessageCollection(responseMessage); - } - } - - // Return login response from the base class - return base.OnLogin7Request(session, request); - } - - public static TransientFaultTDSServer StartTestServer(bool isEnabledTransientFault, bool enableLog, uint errorNumber, [CallerMemberName] string methodName = "") - => StartServerWithQueryEngine(null, isEnabledTransientFault, enableLog, errorNumber, methodName); - - public static TransientFaultTDSServer StartServerWithQueryEngine(QueryEngine engine, bool isEnabledTransientFault, bool enableLog, uint errorNumber, [CallerMemberName] string methodName = "") - { - TransientFaultTDSServerArguments args = new TransientFaultTDSServerArguments() - { - Log = enableLog ? Console.Out : null, - IsEnabledTransientError = isEnabledTransientFault, - Number = errorNumber, - Message = GetErrorMessage(errorNumber) - }; - - TransientFaultTDSServer server = engine == null ? new TransientFaultTDSServer(args) : new TransientFaultTDSServer(engine, args); - server._endpoint = new TDSServerEndPoint(server) { ServerEndPoint = new IPEndPoint(IPAddress.Any, 0) }; - server._endpoint.EndpointName = methodName; - - // The server EventLog should be enabled as it logs the exceptions. - server._endpoint.EventLog = enableLog ? Console.Out : null; - server._endpoint.Start(); - - server.Port = server._endpoint.ServerEndPoint.Port; - return server; - } - - public void Dispose() => Dispose(true); - - private void Dispose(bool isDisposing) - { - if (isDisposing) - { - _endpoint?.Stop(); - RequestCounter = 0; - } - } - } -} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTdsServer.cs new file mode 100644 index 0000000000..6b8c839a6a --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTdsServer.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.SqlServer.TDS.Done; +using Microsoft.SqlServer.TDS.EndPoint; +using Microsoft.SqlServer.TDS.Error; +using Microsoft.SqlServer.TDS.Login7; + +namespace Microsoft.SqlServer.TDS.Servers +{ + /// + /// TDS Server that authenticates clients according to the requested parameters + /// + public class TransientFaultTdsServer : GenericTdsServer, IDisposable + { + private int RequestCounter = 0; + + public void SetErrorBehavior(bool isEnabledTransientFault, uint errorNumber, int repeatCount = 1, string message = null) + { + Arguments.IsEnabledTransientError = isEnabledTransientFault; + Arguments.Number = errorNumber; + Arguments.Message = message; + Arguments.RepeatCount = repeatCount; + } + + public TransientFaultTdsServer(TransientFaultTdsServerArguments arguments) : base(arguments) + { + } + + public TransientFaultTdsServer(TransientFaultTdsServerArguments arguments, QueryEngine queryEngine) : base(arguments, queryEngine) + { + } + + private static string GetErrorMessage(uint errorNumber) + { + switch (errorNumber) + { + case 40613: + return "Database on server is not currently available. Please retry the connection later. " + + "If the problem persists, contact customer support, and provide them the session tracing ID."; + case 42108: + return "Can not connect to the SQL pool since it is paused. Please resume the SQL pool and try again."; + case 42109: + return "The SQL pool is warming up. Please try again."; + } + return "Unknown server error occurred"; + } + + /// + /// Handler for login request + /// + public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, TDSMessage request) + { + // Inflate login7 request from the message + TDSLogin7Token loginRequest = request[0] as TDSLogin7Token; + + // Check if we're still going to raise transient error + if (Arguments.IsEnabledTransientError && RequestCounter < Arguments.RepeatCount) + { + uint errorNumber = Arguments.Number; + string errorMessage = Arguments.Message ?? GetErrorMessage(errorNumber); + + // Log request to which we're about to send a failure + TDSUtilities.Log(Arguments.Log, "Request", loginRequest); + + // Prepare ERROR token with the denial details + TDSErrorToken errorToken = new TDSErrorToken(errorNumber, 1, 20, errorMessage); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); + + // Serialize the error token into the response packet + TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); + + // Create DONE token + TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); + + // Log response + TDSUtilities.Log(Arguments.Log, "Response", doneToken); + + // Serialize DONE token into the response packet + responseMessage.Add(doneToken); + + RequestCounter++; + + // Put a single message into the collection and return it + return new TDSMessageCollection(responseMessage); + } + + // Return login response from the base class + return base.OnLogin7Request(session, request); + } + + public override void Dispose() { + base.Dispose(); + RequestCounter = 0; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServerArguments.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTdsServerArguments.cs similarity index 60% rename from src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServerArguments.cs rename to src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTdsServerArguments.cs index 77eec68c5f..22ba9e83cb 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTDSServerArguments.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientFaultTdsServerArguments.cs @@ -4,31 +4,26 @@ namespace Microsoft.SqlServer.TDS.Servers { - public class TransientFaultTDSServerArguments : TDSServerArguments + public class TransientFaultTdsServerArguments : TdsServerArguments { /// /// Transient error number to be raised by server. /// - public uint Number { get; set; } + public uint Number = 0; /// /// Transient error message to be raised by server. /// - public string Message { get; set; } + public string Message = string.Empty; /// /// Flag to consider when raising Transient error. /// - public bool IsEnabledTransientError { get; set; } + public bool IsEnabledTransientError = false; /// - /// Constructor to initialize + /// The number of times the transient error should be raised. /// - public TransientFaultTDSServerArguments() - { - Number = 0; - Message = string.Empty; - IsEnabledTransientError = false; - } + public int RepeatCount = 1; } }