Skip to content

Commit 1ec7675

Browse files
Extract SqlBulkCopy column names using dynamic SQL (#4092)
Co-authored-by: Edward Neal <55035479+edwardneal@users.noreply.github.com>
1 parent f8183c1 commit 1ec7675

7 files changed

Lines changed: 103 additions & 43 deletions

File tree

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -472,29 +472,51 @@ private string CreateInitialQuery()
472472
}
473473
else if (!string.IsNullOrEmpty(CatalogName))
474474
{
475-
CatalogName = SqlServerEscapeHelper.EscapeIdentifier(CatalogName);
475+
CatalogName = SqlServerEscapeHelper.EscapeStringAsLiteral(SqlServerEscapeHelper.EscapeIdentifier(CatalogName));
476476
}
477477

478478
string objectName = ADP.BuildMultiPartName(parts);
479479
string escapedObjectName = SqlServerEscapeHelper.EscapeStringAsLiteral(objectName);
480-
// Specify the column names explicitly. This is to ensure that we can map to hidden columns (e.g. columns in temporal tables.)
481-
// If the target table doesn't exist, OBJECT_ID will return NULL and @Column_Names will remain non-null. The subsequent SELECT *
482-
// query will then continue to fail with "Invalid object name" rather than with an unusual error because the query being executed
483-
// is NULL.
484-
// Some hidden columns (e.g. SQL Graph columns) cannot be selected, so we need to exclude them explicitly.
480+
// Specify the column names explicitly. This is to ensure that we can map to hidden
481+
// columns (e.g. columns in temporal tables.) If the target table doesn't exist,
482+
// OBJECT_ID will return NULL and @Column_Names will remain non-null. The subsequent
483+
// SELECT * query will then continue to fail with "Invalid object name" rather than with
484+
// an unusual error because the query being executed is NULL.
485+
//
486+
// Some hidden columns (e.g. SQL Graph columns) cannot be selected, so we need to
487+
// exclude them explicitly. The graph_type values excluded below are internal graph
488+
// columns that cannot be selected directly:
489+
//
490+
// 1 = GRAPH_ID
491+
// 3 = GRAPH_FROM_ID
492+
// 4 = GRAPH_FROM_OBJ_ID
493+
// 6 = GRAPH_TO_ID
494+
// 7 = GRAPH_TO_OBJ_ID
495+
//
496+
// See: https://learn.microsoft.com/sql/relational-databases/graphs/sql-graph-architecture#syscolumns
497+
//
498+
// The column-name query is built as dynamic SQL and executed via sp_executesql so
499+
// that it is not compiled (and rejected) on SQL Server versions that lack the
500+
// graph_type column (e.g. SQL 2016). CatalogName and escapedObjectName are
501+
// interpolated directly into the SQL string because SQL Server does not allow
502+
// identifiers (database/schema/table names) to be passed as parameters. Both
503+
// values are escaped via SqlServerEscapeHelper before interpolation.
485504
return $"""
486505
SELECT @@TRANCOUNT;
487506
507+
DECLARE @Object_ID INT = OBJECT_ID('{escapedObjectName}');
508+
DECLARE @Column_Name_Query NVARCHAR(MAX);
488509
DECLARE @Column_Names NVARCHAR(MAX) = NULL;
489510
IF EXISTS (SELECT TOP 1 * FROM sys.all_columns WHERE [object_id] = OBJECT_ID('sys.all_columns') AND [name] = 'graph_type')
490511
BEGIN
491-
SELECT @Column_Names = COALESCE(@Column_Names + ', ', '') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7) ORDER BY [column_id] ASC;
512+
SET @Column_Name_Query = N'SELECT @Column_Names = COALESCE(@Column_Names + '', '', '''') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7) ORDER BY [column_id] ASC;';
492513
END
493514
ELSE
494515
BEGIN
495-
SELECT @Column_Names = COALESCE(@Column_Names + ', ', '') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = OBJECT_ID('{escapedObjectName}') ORDER BY [column_id] ASC;
516+
SET @Column_Name_Query = N'SELECT @Column_Names = COALESCE(@Column_Names + '', '', '''') + QUOTENAME([name]) FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID ORDER BY [column_id] ASC;';
496517
END
497518
519+
EXEC sp_executesql @Column_Name_Query, N'@Object_ID INT, @Column_Names NVARCHAR(MAX) OUTPUT', @Object_ID = @Object_ID, @Column_Names = @Column_Names OUTPUT;
498520
SELECT @Column_Names = COALESCE(@Column_Names, '*');
499521
500522
SET FMTONLY ON;
@@ -624,7 +646,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i
624646

625647
bool matched = false;
626648
bool rejected = false;
627-
649+
628650
// Look for a local match for the remote column.
629651
for (int j = 0; j < _localColumnMappings.Count; ++j)
630652
{
@@ -644,7 +666,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i
644666

645667
// Remove it from our unmatched set.
646668
unmatchedColumns.Remove(localColumn.DestinationColumn);
647-
669+
648670
// Check for column types that we refuse to bulk load, even
649671
// though we found a match.
650672
//
@@ -1437,7 +1459,7 @@ private void RunParserReliably(BulkCopySimpleResultSet bulkCopyHandler = null)
14371459
try
14381460
{
14391461
// @TODO: CER Exception Handling was removed here (see GH#3581)
1440-
_parser.Run(RunBehavior.UntilDone, null, null, bulkCopyHandler, _stateObj);
1462+
_parser.Run(RunBehavior.UntilDone, null, null, bulkCopyHandler, _stateObj);
14411463
}
14421464
finally
14431465
{
@@ -1760,7 +1782,7 @@ public void WriteToServer(DbDataReader reader)
17601782
try
17611783
{
17621784
statistics = SqlStatistics.StartTimer(Statistics);
1763-
1785+
17641786
ResetWriteToServerGlobalVariables();
17651787
_rowSource = reader;
17661788
_dbDataReaderRowSource = reader;
@@ -1796,13 +1818,13 @@ public void WriteToServer(IDataReader reader)
17961818
try
17971819
{
17981820
statistics = SqlStatistics.StartTimer(Statistics);
1799-
1821+
18001822
ResetWriteToServerGlobalVariables();
18011823
_rowSource = reader;
18021824
_sqlDataReaderRowSource = _rowSource as SqlDataReader;
18031825
_dbDataReaderRowSource = _rowSource as DbDataReader;
18041826
_rowSourceType = ValueSourceType.IDataReader;
1805-
1827+
18061828
WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false;
18071829
}
18081830
finally
@@ -1918,7 +1940,7 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok
19181940
try
19191941
{
19201942
statistics = SqlStatistics.StartTimer(Statistics);
1921-
1943+
19221944
ResetWriteToServerGlobalVariables();
19231945
if (rows.Length == 0)
19241946
{
@@ -1935,9 +1957,9 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok
19351957
_rowSourceType = ValueSourceType.RowArray;
19361958
_rowEnumerator = rows.GetEnumerator();
19371959
_isAsyncBulkCopy = true;
1938-
1960+
19391961
// It returns Task since _isAsyncBulkCopy = true;
1940-
return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken);
1962+
return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken);
19411963
}
19421964
finally
19431965
{
@@ -1964,19 +1986,19 @@ public Task WriteToServerAsync(DbDataReader reader, CancellationToken cancellati
19641986
{
19651987
throw SQL.BulkLoadPendingOperation();
19661988
}
1967-
1989+
19681990
SqlStatistics statistics = Statistics;
19691991
try
19701992
{
19711993
statistics = SqlStatistics.StartTimer(Statistics);
1972-
1994+
19731995
ResetWriteToServerGlobalVariables();
19741996
_rowSource = reader;
19751997
_sqlDataReaderRowSource = reader as SqlDataReader;
19761998
_dbDataReaderRowSource = reader;
19771999
_rowSourceType = ValueSourceType.DbDataReader;
19782000
_isAsyncBulkCopy = true;
1979-
2001+
19802002
// It returns Task since _isAsyncBulkCopy = true;
19812003
return WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken);
19822004
}
@@ -2016,7 +2038,7 @@ public Task WriteToServerAsync(IDataReader reader, CancellationToken cancellatio
20162038
_dbDataReaderRowSource = _rowSource as DbDataReader;
20172039
_rowSourceType = ValueSourceType.IDataReader;
20182040
_isAsyncBulkCopy = true;
2019-
2041+
20202042
// It returns Task since _isAsyncBulkCopy = true;
20212043
return WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken);
20222044
}
@@ -2056,15 +2078,15 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat
20562078
try
20572079
{
20582080
statistics = SqlStatistics.StartTimer(Statistics);
2059-
2081+
20602082
ResetWriteToServerGlobalVariables();
20612083
_rowStateToSkip = ((rowState == 0) || (rowState == DataRowState.Deleted)) ? DataRowState.Deleted : ~rowState | DataRowState.Deleted;
20622084
_rowSource = table;
20632085
_dataTableSource = table;
20642086
_rowSourceType = ValueSourceType.DataTable;
20652087
_rowEnumerator = table.Rows.GetEnumerator();
20662088
_isAsyncBulkCopy = true;
2067-
2089+
20682090
// It returns Task since _isAsyncBulkCopy = true;
20692091
return WriteRowSourceToServerAsync(table.Columns.Count, cancellationToken);
20702092
}
@@ -2114,7 +2136,7 @@ private Task WriteRowSourceToServerAsync(int columnCount, CancellationToken ctok
21142136

21152137
bool finishedSynchronously = true;
21162138
_isBulkCopyingInProgress = true;
2117-
2139+
21182140
CreateOrValidateConnection(nameof(WriteToServer));
21192141

21202142
SqlConnectionInternal internalConnection = _connection.GetOpenTdsConnection();
@@ -3065,11 +3087,11 @@ private void WriteToServerInternalRestAsync(CancellationToken cts, TaskCompletio
30653087

30663088
// No need to cancel timer since SqlBulkCopy creates specific task source for reconnection.
30673089
AsyncHelper.SetTimeoutExceptionWithState(
3068-
completion: cancellableReconnectTS,
3090+
completion: cancellableReconnectTS,
30693091
timeout: BulkCopyTimeout,
30703092
state: _destinationTableName,
3071-
onFailure: static state =>
3072-
SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()),
3093+
onFailure: static state =>
3094+
SQL.BulkLoadInvalidDestinationTable((string)state, SQL.CR_ReconnectTimeout()),
30733095
cancellationToken: CancellationToken.None
30743096
);
30753097

@@ -3242,7 +3264,7 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken)
32423264
}
32433265
return resultTask;
32443266
}
3245-
3267+
32463268
private void ResetWriteToServerGlobalVariables()
32473269
{
32483270
_dataTableSource = null;

src/Microsoft.Data.SqlClient/tests/ManualTests/DataCommon/DataTestUtility.cs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public static class DataTestUtility
8686
internal static readonly string KerberosDomainPassword = null;
8787

8888
// SQL server Version
89-
private static string s_sQLServerVersion = string.Empty;
89+
private static string s_sqlServerVersion;
9090

9191
//SQL Server EngineEdition
9292
private static string s_sqlServerEngineEdition;
@@ -125,9 +125,9 @@ public static string SQLServerVersion
125125
{
126126
if (!string.IsNullOrEmpty(TCPConnectionString))
127127
{
128-
s_sQLServerVersion ??= GetSqlServerProperty(TCPConnectionString, ServerProperty.ProductMajorVersion);
128+
s_sqlServerVersion ??= GetSqlServerProperty(TCPConnectionString, ServerProperty.ProductMajorVersion);
129129
}
130-
return s_sQLServerVersion;
130+
return s_sqlServerVersion;
131131
}
132132
}
133133

@@ -491,7 +491,14 @@ public static bool AreConnStringsSetup()
491491

492492
public static bool IsSQL2019() => string.Equals("15", SQLServerVersion.Trim());
493493

494-
public static bool IsSQL2016() => string.Equals("14", s_sQLServerVersion.Trim());
494+
public static bool IsSQL2017() => string.Equals("14", SQLServerVersion.Trim());
495+
496+
public static bool IsSQL2016() => string.Equals("13", SQLServerVersion.Trim());
497+
498+
// "At least" version checks for use as ConditionalFact/ConditionalTheory conditions.
499+
public static bool IsAtLeastSQL2017() => int.TryParse(SQLServerVersion?.Trim(), out int major) && major >= 14;
500+
501+
public static bool IsAtLeastSQL2019() => int.TryParse(SQLServerVersion?.Trim(), out int major) && major >= 15;
495502

496503
public static bool IsSQLAliasSetup()
497504
{

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTest.cs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public class CertificateTest : IDisposable
3232

3333
// InstanceName will get replaced with an instance name in the connection string
3434
private static string InstanceName = "MSSQLSERVER";
35-
35+
3636
// s_instanceNamePrefix will get replaced with MSSQL$ is there is an instance name in connection string
3737
private static string InstanceNamePrefix = "";
3838

@@ -51,10 +51,14 @@ private static string ForceEncryptionRegistryPath
5151
{
5252
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL15.{InstanceName}\MSSQLSERVER\SuperSocketNetLib";
5353
}
54-
if (DataTestUtility.IsSQL2016())
54+
if (DataTestUtility.IsSQL2017())
5555
{
5656
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL14.{InstanceName}\MSSQLSERVER\SuperSocketNetLib";
5757
}
58+
if (DataTestUtility.IsSQL2016())
59+
{
60+
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL13.{InstanceName}\MSSQLSERVER\SuperSocketNetLib";
61+
}
5862
return string.Empty;
5963
}
6064
}
@@ -196,7 +200,9 @@ private static void CreateValidCertificate(string script)
196200
RedirectStandardError = true,
197201
RedirectStandardOutput = true,
198202
UseShellExecute = false,
199-
Arguments = $"{script} -Prefix {InstanceNamePrefix} -Instance {InstanceName}",
203+
Arguments = string.IsNullOrEmpty(InstanceNamePrefix)
204+
? $"{script} -Instance \"{InstanceName}\""
205+
: $"{script} -Prefix \"{InstanceNamePrefix}\" -Instance \"{InstanceName}\"",
200206
CreateNoWindow = false,
201207
Verb = "runas"
202208
}
@@ -224,7 +230,12 @@ private static void CreateValidCertificate(string script)
224230
proc.Kill();
225231
// allow async output to process
226232
proc.WaitForExit(2000);
227-
throw new Exception($"Could not generate certificate.Error out put: {output}");
233+
throw new Exception($"Could not generate certificate. Error output: {output}");
234+
}
235+
236+
if (proc.ExitCode != 0)
237+
{
238+
throw new Exception($"Certificate generation script failed with exit code {proc.ExitCode}. Output: {output}");
228239
}
229240
}
230241
else
@@ -252,6 +263,11 @@ private static string GetLocalIpAddress()
252263

253264
private void RemoveCertificate()
254265
{
266+
if (string.IsNullOrEmpty(_thumbprint))
267+
{
268+
return;
269+
}
270+
255271
using X509Store certStore = new(StoreName.Root, StoreLocation.LocalMachine);
256272
certStore.Open(OpenFlags.ReadWrite);
257273
X509Certificate2Collection certCollection = certStore.Certificates.Find(X509FindType.FindByThumbprint, _thumbprint, false);

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectionTestWithSSLCert/CertificateTestWithTdsServer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ private static string ForceEncryptionRegistryPath
8989
{
9090
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL15.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib";
9191
}
92-
if (DataTestUtility.IsSQL2016())
92+
if (DataTestUtility.IsSQL2017())
9393
{
9494
return $@"SOFTWARE\Microsoft\Microsoft SQL Server\MSSQL14.{s_instanceName}\MSSQLSERVER\SuperSocketNetLib";
9595
}

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ public static void Test(string srcConstr, string dstConstr, string dstTable)
6161
DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersReceived"], "Unexpected BuffersReceived value.");
6262
DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersSent"], "Unexpected BuffersSent value.");
6363
DataTestUtility.AssertEqualsWithDescription((long)0, stats["IduCount"], "Unexpected IduCount value.");
64-
DataTestUtility.AssertEqualsWithDescription((long)6, stats["SelectCount"], "Unexpected SelectCount value.");
64+
DataTestUtility.AssertEqualsWithDescription((long)8, stats["SelectCount"], "Unexpected SelectCount value.");
6565
DataTestUtility.AssertEqualsWithDescription((long)3, stats["ServerRoundtrips"], "Unexpected ServerRoundtrips value.");
66-
DataTestUtility.AssertEqualsWithDescription((long)9, stats["SelectRows"], "Unexpected SelectRows value.");
66+
DataTestUtility.AssertEqualsWithDescription((long)11, stats["SelectRows"], "Unexpected SelectRows value.");
6767
DataTestUtility.AssertEqualsWithDescription((long)2, stats["SumResultSets"], "Unexpected SumResultSets value.");
6868
DataTestUtility.AssertEqualsWithDescription((long)0, stats["Transactions"], "Unexpected Transactions value.");
6969
}

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace Microsoft.Data.SqlClient.ManualTesting.Tests.SqlBulkCopyTests
1111
{
1212
public class SqlGraphTables
1313
{
14-
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse))]
14+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.IsAtLeastSQL2017))]
1515
public void WriteToServer_CopyToSqlGraphNodeTable_Succeeds()
1616
{
1717
string connectionString = DataTestUtility.TCPConnectionString;

0 commit comments

Comments
 (0)