diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs index 90979eeec9..07098a3792 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.netcore.cs @@ -2236,7 +2236,7 @@ private void CheckNotificationStateAndAutoEnlist() } Notification.Options = SqlDependency.GetDefaultComposedOptions(_activeConnection.DataSource, - InternalTdsConnection.ServerProvidedFailOverPartner, + InternalTdsConnection.ServerProvidedFailoverPartner, identityUserName, _activeConnection.Database); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 1871fe6087..6eedfd2fa9 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -314,7 +314,6 @@ internal SessionData CurrentSessionData // FOR CONNECTION RESET MANAGEMENT private bool _fResetConnection; private string _originalDatabase; - private string _currentFailoverPartner; // only set by ENV change from server private string _originalLanguage; private string _currentLanguage; private int _currentPacketSize; @@ -704,13 +703,7 @@ internal TdsParser Parser } } - internal string ServerProvidedFailOverPartner - { - get - { - return _currentFailoverPartner; - } - } + internal string ServerProvidedFailoverPartner { get; set; } internal SqlConnectionPoolGroupProviderInfo PoolGroupProviderInfo { @@ -1507,7 +1500,7 @@ private void OpenLoginEnlist(TimeoutTimer timeout, throw SQL.ROR_FailoverNotSupportedConnString(); } - if (ServerProvidedFailOverPartner != null) + if (ServerProvidedFailoverPartner != null) { throw SQL.ROR_FailoverNotSupportedServer(this); } @@ -1635,7 +1628,7 @@ private void LoginNoFailover(ServerInfo serverInfo, newSecurePassword, attemptOneLoginTimeout); - if (connectionOptions.MultiSubnetFailover && ServerProvidedFailOverPartner != null) + if (connectionOptions.MultiSubnetFailover && ServerProvidedFailoverPartner != null) { // connection succeeded: trigger exception if server sends failover partner and MultiSubnetFailover is used throw SQL.MultiSubnetFailoverWithFailoverPartner(serverProvidedFailoverPartner: true, internalConnection: this); @@ -1663,7 +1656,7 @@ private void LoginNoFailover(ServerInfo serverInfo, _currentPacketSize = ConnectionOptions.PacketSize; _currentLanguage = _originalLanguage = ConnectionOptions.CurrentLanguage; CurrentDatabase = _originalDatabase = ConnectionOptions.InitialCatalog; - _currentFailoverPartner = null; + ServerProvidedFailoverPartner = null; _instanceName = string.Empty; routingAttempts++; @@ -1702,7 +1695,7 @@ private void LoginNoFailover(ServerInfo serverInfo, // We only get here when we failed to connect, but are going to re-try // Switch to failover logic if the server provided a partner - if (ServerProvidedFailOverPartner != null) + if (ServerProvidedFailoverPartner != null) { if (connectionOptions.MultiSubnetFailover) { @@ -1718,7 +1711,7 @@ private void LoginNoFailover(ServerInfo serverInfo, LoginWithFailover( true, // start by using failover partner, since we already failed to connect to the primary serverInfo, - ServerProvidedFailOverPartner, + ServerProvidedFailoverPartner, newPassword, newSecurePassword, redirectedUserInstance, @@ -1740,8 +1733,13 @@ private void LoginNoFailover(ServerInfo serverInfo, { // We must wait for CompleteLogin to finish for to have the // env change from the server to know its designated failover - // partner; save this information in _currentFailoverPartner. - PoolGroupProviderInfo.FailoverCheck(false, connectionOptions, ServerProvidedFailOverPartner); + // partner; save this information in ServerProvidedFailoverPartner. + + // When ignoring server provided failover partner, we must pass in the original failover partner from the connection string. + // Otherwise the pool group's failover partner designation will be updated to point to the server provided value. + string actualFailoverPartner = LocalAppContextSwitches.IgnoreServerProvidedFailoverPartner ? "" : ServerProvidedFailoverPartner; + + PoolGroupProviderInfo.FailoverCheck(false, connectionOptions, actualFailoverPartner); } CurrentDataSource = originalServerInfo.UserServerName; } @@ -1802,7 +1800,7 @@ TimeoutTimer timeout ServerInfo failoverServerInfo = new ServerInfo(connectionOptions, failoverHost, connectionOptions.FailoverPartnerSPN); ResolveExtendedServerName(primaryServerInfo, !redirectedUserInstance, connectionOptions); - if (ServerProvidedFailOverPartner == null) + if (ServerProvidedFailoverPartner == null) { ResolveExtendedServerName(failoverServerInfo, !redirectedUserInstance && failoverHost != primaryServerInfo.UserServerName, connectionOptions); } @@ -1861,12 +1859,21 @@ TimeoutTimer timeout failoverDemandDone = true; } - // Primary server may give us a different failover partner than the connection string indicates. Update it - if (ServerProvidedFailOverPartner != null && failoverServerInfo.ResolvedServerName != ServerProvidedFailOverPartner) + // Primary server may give us a different failover partner than the connection string indicates. + // Update it only if we are respecting server-provided failover partner values. + if (ServerProvidedFailoverPartner != null && failoverServerInfo.ResolvedServerName != ServerProvidedFailoverPartner) { - SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, new failover partner={1}", ObjectID, ServerProvidedFailOverPartner); - failoverServerInfo.SetDerivedNames(string.Empty, ServerProvidedFailOverPartner); + if (LocalAppContextSwitches.IgnoreServerProvidedFailoverPartner) + { + SqlClientEventSource.Log.TryTraceEvent(" {0}, Ignoring server provided failover partner '{1}' due to IgnoreServerProvidedFailoverPartner AppContext switch.", ObjectID, ServerProvidedFailoverPartner); + } + else + { + SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, new failover partner={1}", ObjectID, ServerProvidedFailoverPartner); + failoverServerInfo.SetDerivedNames(string.Empty, ServerProvidedFailoverPartner); + } } + currentServerInfo = failoverServerInfo; _timeoutErrorInternal.SetInternalSourceType(SqlConnectionInternalSourceType.Failover); } @@ -1916,7 +1923,7 @@ TimeoutTimer timeout _currentPacketSize = connectionOptions.PacketSize; _currentLanguage = _originalLanguage = ConnectionOptions.CurrentLanguage; CurrentDatabase = _originalDatabase = connectionOptions.InitialCatalog; - _currentFailoverPartner = null; + ServerProvidedFailoverPartner = null; _instanceName = string.Empty; AttemptOneLogin( @@ -1978,7 +1985,7 @@ TimeoutTimer timeout _activeDirectoryAuthTimeoutRetryHelper.State = ActiveDirectoryAuthenticationTimeoutRetryState.HasLoggedIn; // if connected to failover host, but said host doesn't have DbMirroring set up, throw an error - if (useFailoverHost && ServerProvidedFailOverPartner == null) + if (useFailoverHost && ServerProvidedFailoverPartner == null) { throw SQL.InvalidPartnerConfiguration(failoverHost, CurrentDatabase); } @@ -1987,8 +1994,13 @@ TimeoutTimer timeout { // We must wait for CompleteLogin to finish for to have the // env change from the server to know its designated failover - // partner; save this information in _currentFailoverPartner. - PoolGroupProviderInfo.FailoverCheck(useFailoverHost, connectionOptions, ServerProvidedFailOverPartner); + // partner. + + // When ignoring server provided failover partner, we must pass in the original failover partner from the connection string. + // Otherwise the pool group's failover partner designation will be updated to point to the server provided value. + string actualFailoverPartner = LocalAppContextSwitches.IgnoreServerProvidedFailoverPartner ? failoverHost : ServerProvidedFailoverPartner; + + PoolGroupProviderInfo.FailoverCheck(useFailoverHost, connectionOptions, actualFailoverPartner); } CurrentDataSource = (useFailoverHost ? failoverHost : primaryServerInfo.UserServerName); } @@ -2218,7 +2230,8 @@ internal void OnEnvChange(SqlEnvChange rec) { throw SQL.ROR_FailoverNotSupportedServer(this); } - _currentFailoverPartner = rec._newValue; + + ServerProvidedFailoverPartner = rec._newValue; break; case TdsEnums.ENV_PROMOTETRANSACTION: diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs index 24c02fd0ae..61bf90e349 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.netfx.cs @@ -2205,7 +2205,7 @@ private void CheckNotificationStateAndAutoEnlist() } Notification.Options = SqlDependency.GetDefaultComposedOptions(_activeConnection.DataSource, - InternalTdsConnection.ServerProvidedFailOverPartner, + InternalTdsConnection.ServerProvidedFailoverPartner, identityUserName, _activeConnection.Database); } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs index 66dd0e7340..9df86e4fd2 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs @@ -326,7 +326,6 @@ internal SessionData CurrentSessionData // FOR CONNECTION RESET MANAGEMENT private bool _fResetConnection; private string _originalDatabase; - private string _currentFailoverPartner; // only set by ENV change from server private string _originalLanguage; private string _currentLanguage; private int _currentPacketSize; @@ -714,13 +713,7 @@ internal TdsParser Parser } } - internal string ServerProvidedFailOverPartner - { - get - { - return _currentFailoverPartner; - } - } + internal string ServerProvidedFailoverPartner { get; set; } internal SqlConnectionPoolGroupProviderInfo PoolGroupProviderInfo { @@ -1515,7 +1508,7 @@ private void OpenLoginEnlist(TimeoutTimer timeout, throw SQL.ROR_FailoverNotSupportedConnString(); } - if (ServerProvidedFailOverPartner != null) + if (ServerProvidedFailoverPartner != null) { throw SQL.ROR_FailoverNotSupportedServer(this); } @@ -1665,7 +1658,7 @@ private void LoginNoFailover(ServerInfo serverInfo, isFirstTransparentAttempt: isFirstTransparentAttempt, disableTnir: disableTnir); - if (connectionOptions.MultiSubnetFailover && ServerProvidedFailOverPartner != null) + if (connectionOptions.MultiSubnetFailover && ServerProvidedFailoverPartner != null) { // connection succeeded: trigger exception if server sends failover partner and MultiSubnetFailover is used throw SQL.MultiSubnetFailoverWithFailoverPartner(serverProvidedFailoverPartner: true, internalConnection: this); @@ -1693,7 +1686,7 @@ private void LoginNoFailover(ServerInfo serverInfo, _currentPacketSize = ConnectionOptions.PacketSize; _currentLanguage = _originalLanguage = ConnectionOptions.CurrentLanguage; CurrentDatabase = _originalDatabase = ConnectionOptions.InitialCatalog; - _currentFailoverPartner = null; + ServerProvidedFailoverPartner = null; _instanceName = string.Empty; routingAttempts++; @@ -1735,7 +1728,7 @@ private void LoginNoFailover(ServerInfo serverInfo, // We only get here when we failed to connect, but are going to re-try // Switch to failover logic if the server provided a partner - if (ServerProvidedFailOverPartner != null) + if (ServerProvidedFailoverPartner != null) { if (connectionOptions.MultiSubnetFailover) { @@ -1751,7 +1744,7 @@ private void LoginNoFailover(ServerInfo serverInfo, LoginWithFailover( true, // start by using failover partner, since we already failed to connect to the primary serverInfo, - ServerProvidedFailOverPartner, + ServerProvidedFailoverPartner, newPassword, newSecurePassword, redirectedUserInstance, @@ -1773,8 +1766,13 @@ private void LoginNoFailover(ServerInfo serverInfo, { // We must wait for CompleteLogin to finish for to have the // env change from the server to know its designated failover - // partner; save this information in _currentFailoverPartner. - PoolGroupProviderInfo.FailoverCheck(false, connectionOptions, ServerProvidedFailOverPartner); + // partner; save this information in ServerProvidedFailoverPartner. + + // When ignoring server provided failover partner, we must pass in the original failover partner from the connection string. + // Otherwise the pool group's failover partner designation will be updated to point to the server provided value. + string actualFailoverPartner = LocalAppContextSwitches.IgnoreServerProvidedFailoverPartner ? "" : ServerProvidedFailoverPartner; + + PoolGroupProviderInfo.FailoverCheck(false, connectionOptions, actualFailoverPartner); } CurrentDataSource = originalServerInfo.UserServerName; } @@ -1858,7 +1856,7 @@ TimeoutTimer timeout ServerInfo failoverServerInfo = new ServerInfo(connectionOptions, failoverHost, connectionOptions.FailoverPartnerSPN); ResolveExtendedServerName(primaryServerInfo, !redirectedUserInstance, connectionOptions); - if (ServerProvidedFailOverPartner == null) + if (ServerProvidedFailoverPartner == null) { ResolveExtendedServerName(failoverServerInfo, !redirectedUserInstance && failoverHost != primaryServerInfo.UserServerName, connectionOptions); } @@ -1915,12 +1913,21 @@ TimeoutTimer timeout failoverDemandDone = true; } - // Primary server may give us a different failover partner than the connection string indicates. Update it - if (ServerProvidedFailOverPartner != null && failoverServerInfo.ResolvedServerName != ServerProvidedFailOverPartner) + // Primary server may give us a different failover partner than the connection string indicates. + // Update it only if we are respecting server-provided failover partner values. + if (ServerProvidedFailoverPartner != null && failoverServerInfo.ResolvedServerName != ServerProvidedFailoverPartner) { - SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, new failover partner={1}", ObjectID, ServerProvidedFailOverPartner); - failoverServerInfo.SetDerivedNames(protocol, ServerProvidedFailOverPartner); + if (LocalAppContextSwitches.IgnoreServerProvidedFailoverPartner) + { + SqlClientEventSource.Log.TryTraceEvent(" {0}, Ignoring server provided failover partner '{1}' due to IgnoreServerProvidedFailoverPartner AppContext switch.", ObjectID, ServerProvidedFailoverPartner); + } + else + { + SqlClientEventSource.Log.TryAdvancedTraceEvent(" {0}, new failover partner={1}", ObjectID, ServerProvidedFailoverPartner); + failoverServerInfo.SetDerivedNames(protocol, ServerProvidedFailoverPartner); + } } + currentServerInfo = failoverServerInfo; _timeoutErrorInternal.SetInternalSourceType(SqlConnectionInternalSourceType.Failover); } @@ -1970,7 +1977,7 @@ TimeoutTimer timeout _currentPacketSize = connectionOptions.PacketSize; _currentLanguage = _originalLanguage = ConnectionOptions.CurrentLanguage; CurrentDatabase = _originalDatabase = connectionOptions.InitialCatalog; - _currentFailoverPartner = null; + ServerProvidedFailoverPartner = null; _instanceName = string.Empty; AttemptOneLogin( @@ -2035,7 +2042,7 @@ TimeoutTimer timeout _activeDirectoryAuthTimeoutRetryHelper.State = ActiveDirectoryAuthenticationTimeoutRetryState.HasLoggedIn; // if connected to failover host, but said host doesn't have DbMirroring set up, throw an error - if (useFailoverHost && ServerProvidedFailOverPartner == null) + if (useFailoverHost && ServerProvidedFailoverPartner == null) { throw SQL.InvalidPartnerConfiguration(failoverHost, CurrentDatabase); } @@ -2044,8 +2051,13 @@ TimeoutTimer timeout { // We must wait for CompleteLogin to finish for to have the // env change from the server to know its designated failover - // partner; save this information in _currentFailoverPartner. - PoolGroupProviderInfo.FailoverCheck(useFailoverHost, connectionOptions, ServerProvidedFailOverPartner); + // partner. + + // When ignoring server provided failover partner, we must pass in the original failover partner from the connection string. + // Otherwise the pool group's failover partner designation will be updated to point to the server provided value. + string actualFailoverPartner = LocalAppContextSwitches.IgnoreServerProvidedFailoverPartner ? failoverHost : ServerProvidedFailoverPartner; + + PoolGroupProviderInfo.FailoverCheck(useFailoverHost, connectionOptions, actualFailoverPartner); } CurrentDataSource = (useFailoverHost ? failoverHost : primaryServerInfo.UserServerName); } @@ -2273,7 +2285,7 @@ internal void OnEnvChange(SqlEnvChange rec) break; case TdsEnums.ENV_LOGSHIPNODE: - _currentFailoverPartner = rec._newValue; + ServerProvidedFailoverPartner = rec._newValue; break; case TdsEnums.ENV_PROMOTETRANSACTION: diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs index f546b0436b..31b5e66b98 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/LocalAppContextSwitches.cs @@ -24,6 +24,7 @@ private enum Tristate : byte private const string UseCompatibilityAsyncBehaviourString = @"Switch.Microsoft.Data.SqlClient.UseCompatibilityAsyncBehaviour"; private const string UseConnectionPoolV2String = @"Switch.Microsoft.Data.SqlClient.UseConnectionPoolV2"; private const string TruncateScaledDecimalString = @"Switch.Microsoft.Data.SqlClient.TruncateScaledDecimal"; + private const string IgnoreServerProvidedFailoverPartnerString = @"Switch.Microsoft.Data.SqlClient.IgnoreServerProvidedFailoverPartner"; #if NET private const string GlobalizationInvariantModeString = @"System.Globalization.Invariant"; private const string GlobalizationInvariantModeEnvironmentVariable = "DOTNET_SYSTEM_GLOBALIZATION_INVARIANT"; @@ -43,6 +44,7 @@ private enum Tristate : byte private static Tristate s_useCompatibilityAsyncBehaviour; private static Tristate s_useConnectionPoolV2; private static Tristate s_truncateScaledDecimal; + private static Tristate s_ignoreServerProvidedFailoverPartner; #if NET private static Tristate s_globalizationInvariantMode; private static Tristate s_useManagedNetworking; @@ -231,7 +233,7 @@ public static bool UseMinimumLoginTimeout /// When set to 'true' this will output a scale value of 7 (DEFAULT_VARTIME_SCALE) when the scale /// is explicitly set to zero for VarTime data types ('datetime2', 'datetimeoffset' and 'time') /// If no scale is set explicitly it will continue to output scale of 7 (DEFAULT_VARTIME_SCALE) - /// regardsless of switch value. + /// regardless of switch value. /// This app context switch defaults to 'true'. /// public static bool LegacyVarTimeZeroScaleBehaviour @@ -299,6 +301,34 @@ public static bool TruncateScaledDecimal } } + /// + /// When set to true, the failover partner provided by the server during connection + /// will be ignored. This is useful in scenarios where the application wants to + /// control the failover behavior explicitly (e.g. using a custom port). The application + /// must be kept up to date with the failover configuration of the server. + /// The application will not automatically discover a newly configured failover partner. + /// + /// This app context switch defaults to 'false'. + /// + public static bool IgnoreServerProvidedFailoverPartner + { + get + { + if (s_ignoreServerProvidedFailoverPartner == Tristate.NotInitialized) + { + if (AppContext.TryGetSwitch(IgnoreServerProvidedFailoverPartnerString, out bool returnedValue) && returnedValue) + { + s_ignoreServerProvidedFailoverPartner = Tristate.True; + } + else + { + s_ignoreServerProvidedFailoverPartner = Tristate.False; + } + } + return s_ignoreServerProvidedFailoverPartner == Tristate.True; + } + } + #if NET /// /// .NET Core 2.0 and up supports Globalization Invariant mode, which reduces the size of the required libraries for diff --git a/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs b/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs index df68f86677..2dea4ce022 100644 --- a/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/Common/LocalAppContextSwitchesHelper.cs @@ -31,7 +31,8 @@ public sealed class LocalAppContextSwitchesHelper : IDisposable private readonly PropertyInfo _useCompatibilityAsyncBehaviourProperty; private readonly PropertyInfo _useConnectionPoolV2Property; private readonly PropertyInfo _truncateScaledDecimalProperty; - #if NET + private readonly PropertyInfo _ignoreServerProvidedFailoverPartner; +#if NET private readonly PropertyInfo _globalizationInvariantModeProperty; private readonly PropertyInfo _useManagedNetworkingProperty; #else @@ -57,7 +58,9 @@ public sealed class LocalAppContextSwitchesHelper : IDisposable private readonly Tristate _useConnectionPoolV2Original; private readonly FieldInfo _truncateScaledDecimalField; private readonly Tristate _truncateScaledDecimalOriginal; - #if NET + private readonly FieldInfo _ignoreServerProvidedFailoverPartnerField; + private readonly Tristate _ignoreServerProvidedFailoverPartnerOriginal; +#if NET private readonly FieldInfo _globalizationInvariantModeField; private readonly Tristate _globalizationInvariantModeOriginal; private readonly FieldInfo _useManagedNetworkingField; @@ -155,6 +158,10 @@ void InitProperty(string name, out PropertyInfo property) "TruncateScaledDecimal", out _truncateScaledDecimalProperty); + InitProperty( + "IgnoreServerProvidedFailoverPartner", + out _ignoreServerProvidedFailoverPartner); + #if NET InitProperty( "GlobalizationInvariantMode", @@ -229,7 +236,12 @@ void InitField(string name, out FieldInfo field, out Tristate value) out _truncateScaledDecimalField, out _truncateScaledDecimalOriginal); - #if NET + InitField( + "s_ignoreServerProvidedFailoverPartner", + out _ignoreServerProvidedFailoverPartnerField, + out _ignoreServerProvidedFailoverPartnerOriginal); + +#if NET InitField( "s_globalizationInvariantMode", out _globalizationInvariantModeField, @@ -307,7 +319,11 @@ void RestoreField(FieldInfo field, Tristate value) _truncateScaledDecimalField, _truncateScaledDecimalOriginal); - #if NET + RestoreField( + _ignoreServerProvidedFailoverPartnerField, + _ignoreServerProvidedFailoverPartnerOriginal); + +#if NET RestoreField( _globalizationInvariantModeField, _globalizationInvariantModeOriginal); @@ -408,7 +424,12 @@ public bool TruncateScaledDecimal get => (bool)_truncateScaledDecimalProperty.GetValue(null); } - #if NET + public bool IgnoreServerProvidedFailoverPartner + { + get => (bool)_ignoreServerProvidedFailoverPartner.GetValue(null); + } + +#if NET /// /// Access the LocalAppContextSwitches.GlobalizationInvariantMode property. /// @@ -526,7 +547,13 @@ public Tristate TruncateScaledDecimalField set => SetValue(_truncateScaledDecimalField, value); } - #if NET + public Tristate IgnoreServerProvidedFailoverPartnerField + { + get => GetValue(_ignoreServerProvidedFailoverPartnerField); + set => SetValue(_ignoreServerProvidedFailoverPartnerField, value); + } + +#if NET /// /// Get or set the LocalAppContextSwitches.GlobalizationInvariantMode switch value. /// diff --git a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs index 3fa98d1e18..dfc37d2720 100644 --- a/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs +++ b/src/Microsoft.Data.SqlClient/tests/UnitTests/SimulatedServerTests/ConnectionFailoverTests.cs @@ -4,6 +4,8 @@ using System; using System.Data; +using Microsoft.Data.SqlClient; +using Microsoft.Data.SqlClient.Tests.Common; using Microsoft.SqlServer.TDS.Servers; using Xunit; @@ -345,7 +347,7 @@ public void TransientFault_ShouldConnectToPrimary(uint errorCode) new TdsServerArguments { // Doesn't need to point to a real endpoint, just needs a value specified - FailoverPartner = "localhost:1234", + FailoverPartner = "localhost,1234", }); failoverServer.Start(); @@ -354,7 +356,7 @@ public void TransientFault_ShouldConnectToPrimary(uint errorCode) { IsEnabledTransientError = true, Number = errorCode, - FailoverPartner = $"localhost:{failoverServer.EndPoint.Port}", + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", }); server.Start(); @@ -521,5 +523,68 @@ public void TransientFault_WithUserProvidedPartner_RetryDisabled_ShouldFail(uint Assert.Fail(); } + + [Fact] + public void TransientFault_IgnoreServerProvidedFailoverPartner_ShouldConnectToUserProvidedPartner() + { + // Arrange + using LocalAppContextSwitchesHelper switchesHelper = new(); + switchesHelper.IgnoreServerProvidedFailoverPartnerField = LocalAppContextSwitchesHelper.Tristate.True; + + using TdsServer failoverServer = new( + new TdsServerArguments + { + // Doesn't need to point to a real endpoint, just needs a value specified + FailoverPartner = "localhost,1234", + }); + failoverServer.Start(); + + using TdsServer server = new( + new TdsServerArguments() + { + // Set an invalid failover partner to ensure that the connection fails if the + // server provided failover partner is used. + FailoverPartner = $"invalidhost", + }); + server.Start(); + + SqlConnectionStringBuilder builder = new() + { + DataSource = $"localhost,{server.EndPoint.Port}", + InitialCatalog = "master", + Encrypt = false, + FailoverPartner = $"localhost,{failoverServer.EndPoint.Port}", + // Ensure pooling is enabled so that the failover partner information + // is persisted in the pool group. If pooling is disabled, the server + // provided failover partner will never be used. + Pooling = true + }; + SqlConnection connection = new(builder.ConnectionString); + + // Connect once to the primary to trigger it to send the failover partner + connection.Open(); + Assert.Equal("invalidhost", (connection.InnerConnection as SqlInternalConnectionTds)!.ServerProvidedFailoverPartner); + + // Close the connection to return it to the pool + connection.Close(); + + + // Act + // Dispose of the server to trigger a failover + server.Dispose(); + + // Opening a new connection will use the failover partner stored in the pool group. + // This will fail if the server provided failover partner was stored to the pool group. + using SqlConnection failoverConnection = new(builder.ConnectionString); + failoverConnection.Open(); + + // Assert + Assert.Equal(ConnectionState.Open, failoverConnection.State); + Assert.Equal($"localhost,{failoverServer.EndPoint.Port}", failoverConnection.DataSource); + // 1 for the initial connection + Assert.Equal(1, server.PreLoginCount); + // 1 for the failover connection + Assert.Equal(1, failoverServer.PreLoginCount); + } } } diff --git a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServer.cs b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServer.cs index e5d2e52100..ecd89f5812 100644 --- a/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServer.cs +++ b/src/Microsoft.Data.SqlClient/tests/tools/TDS/TDS.Servers/TransientTdsErrorTdsServer.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Threading; using Microsoft.SqlServer.TDS.Done; using Microsoft.SqlServer.TDS.EndPoint; using Microsoft.SqlServer.TDS.Error; @@ -59,38 +60,54 @@ public override TDSMessageCollection OnLogin7Request(ITDSServerSession session, // Check if we're still going to raise transient error if (Arguments.IsEnabledTransientError && RequestCounter < Arguments.RepeatCount) { - uint errorNumber = Arguments.Number; - string errorMessage = Arguments.Message ?? GetErrorMessage(errorNumber); + return GenerateErrorMessage(request); + } + + // Return login response from the base class + return base.OnLogin7Request(session, request); + } + + /// + public override TDSMessageCollection OnSQLBatchRequest(ITDSServerSession session, TDSMessage message) + { + if (Arguments.IsEnabledTransientError && RequestCounter < Arguments.RepeatCount) + { + return GenerateErrorMessage(message); + } + + return base.OnSQLBatchRequest(session, message); + } - // Log request to which we're about to send a failure - TDSUtilities.Log(Arguments.Log, "Request", loginRequest); + private TDSMessageCollection GenerateErrorMessage(TDSMessage request) + { + uint errorNumber = Arguments.Number; + string errorMessage = Arguments.Message ?? GetErrorMessage(errorNumber); - // Prepare ERROR token with the denial details - TDSErrorToken errorToken = new TDSErrorToken(errorNumber, 1, 20, errorMessage); + // Log request to which we're about to send a failure + TDSUtilities.Log(Arguments.Log, "Request", request); - // Log response - TDSUtilities.Log(Arguments.Log, "Response", errorToken); + // Prepare ERROR token with the denial details + TDSErrorToken errorToken = new TDSErrorToken(errorNumber, 1, 20, errorMessage); - // Serialize the error token into the response packet - TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); + // Log response + TDSUtilities.Log(Arguments.Log, "Response", errorToken); - // Create DONE token - TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); + // Serialize the error token into the response packet + TDSMessage responseMessage = new TDSMessage(TDSMessageType.Response, errorToken); - // Log response - TDSUtilities.Log(Arguments.Log, "Response", doneToken); + // Create DONE token + TDSDoneToken doneToken = new TDSDoneToken(TDSDoneTokenStatusType.Final | TDSDoneTokenStatusType.Error); - // Serialize DONE token into the response packet - responseMessage.Add(doneToken); + // Log response + TDSUtilities.Log(Arguments.Log, "Response", doneToken); - RequestCounter++; + // Serialize DONE token into the response packet + responseMessage.Add(doneToken); - // Put a single message into the collection and return it - return new TDSMessageCollection(responseMessage); - } + RequestCounter++; - // Return login response from the base class - return base.OnLogin7Request(session, request); + // Put a single message into the collection and return it + return new TDSMessageCollection(responseMessage); } public override void Dispose() {