Skip to content

Commit 57b6581

Browse files
authored
Feature | Support SQL Graph column aliases in SqlBulkCopy (dotnet#3677)
1 parent a561c78 commit 57b6581

4 files changed

Lines changed: 236 additions & 9 deletions

File tree

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,17 @@ public SourceColumnMetadata(ValueMethod method, bool isSqlType, bool isDataFeed)
154154
// Transaction count has only one value in one column and one row
155155
// MetaData has n columns but no rows
156156
// Collation has 4 columns and n rows
157+
// Column aliases has 2 columns and n rows
157158

158159
private const int MetaDataResultId = 1;
159160

160161
private const int CollationResultId = 2;
161162
private const int CollationId = 3;
162163

164+
private const int ColumnAliasesResultId = 3;
165+
private const int ColumnCanonicalNameColumnId = 0;
166+
private const int ColumnAliasColumnId = 1;
167+
163168
private const int MAX_LENGTH = 0x7FFFFFFF;
164169

165170
private const int DefaultCommandTimeout = 30;
@@ -495,6 +500,15 @@ private string CreateInitialQuery()
495500
//
496501
// See: https://learn.microsoft.com/sql/relational-databases/graphs/sql-graph-architecture#syscolumns
497502
//
503+
// Other columns have aliases assigned to them. The SQL Graph columns $node_id, $edge_id,
504+
// $to_id and $from_id are actually aliases for columns with different canonical names.
505+
// SqlBulkCopy generates these mappings by searching for columns with the below graph_type
506+
// values.
507+
//
508+
// 2 = GRAPH_ID_COMPUTED = $node_id, $edge_id
509+
// 5 = GRAPH_FROM_ID_COMPUTED = $from_id
510+
// 8 = GRAPH_TO_ID_COMPUTED = $to_id
511+
//
498512
// The column-name query is built as dynamic SQL and executed via sp_executesql so
499513
// that it is not compiled (and rejected) on SQL Server versions that lack the
500514
// graph_type column (e.g. SQL 2016). CatalogName and escapedObjectName are
@@ -522,6 +536,13 @@ private string CreateInitialQuery()
522536
DECLARE @Column_Name_Query NVARCHAR(MAX);
523537
DECLARE @Column_Names NVARCHAR(MAX) = NULL;
524538
539+
CREATE TABLE #Column_Aliases
540+
(
541+
[Canonical_Column_Name] SYSNAME,
542+
[Canonical_Column_Id] INT,
543+
[Aliased_Column_Name] SYSNAME
544+
)
545+
525546
IF CAST(SERVERPROPERTY('EngineEdition') AS INT) = 6
526547
BEGIN
527548
SET @Column_Name_Query_SELECT = N'SELECT @Column_Names = STRING_AGG(CAST(QUOTENAME([name]) AS NVARCHAR(MAX)), '', '') WITHIN GROUP (ORDER BY [column_id] ASC)';
@@ -536,6 +557,17 @@ IF CAST(SERVERPROPERTY('EngineEdition') AS INT) = 6
536557
IF EXISTS (SELECT TOP 1 * FROM sys.all_columns WHERE [object_id] = OBJECT_ID('sys.all_columns') AND [name] = 'graph_type')
537558
BEGIN
538559
SET @Column_Name_Query_FILTER = N'WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) NOT IN (1, 3, 4, 6, 7)';
560+
561+
EXEC sp_executesql N'
562+
INSERT INTO #Column_Aliases ([Canonical_Column_Name], [Canonical_Column_Id], [Aliased_Column_Name])
563+
SELECT [name], [column_id], ''$to_id'' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 8
564+
UNION ALL
565+
SELECT [name], [column_id], ''$from_id'' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 5
566+
UNION ALL
567+
SELECT [name], [column_id], ''$edge_id'' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 2 AND [name] LIKE ''$edge[_]id[_]%''
568+
UNION ALL
569+
SELECT [name], [column_id], ''$node_id'' FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID AND COALESCE([graph_type], 0) = 2 AND [name] LIKE ''$node[_]id[_]%''',
570+
N'@Object_ID INT', @Object_ID = @Object_ID
539571
END
540572
ELSE
541573
BEGIN
@@ -551,6 +583,13 @@ IF EXISTS (SELECT TOP 1 * FROM sys.all_columns WHERE [object_id] = OBJECT_ID('sy
551583
SET FMTONLY OFF;
552584
553585
EXEC {CatalogName}..{TableCollationsStoredProc} N'{SchemaName}.{TableName}';
586+
587+
SELECT [Canonical_Column_Name], [Aliased_Column_Name]
588+
FROM #Column_Aliases
589+
WHERE [Aliased_Column_Name] NOT IN (SELECT [name] FROM {CatalogName}.[sys].[all_columns] WHERE [object_id] = @Object_ID)
590+
ORDER BY [Canonical_Column_Id] ASC
591+
592+
DROP TABLE #Column_Aliases
554593
""";
555594
}
556595

@@ -647,9 +686,9 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i
647686
// Keep track of any result columns that we don't have a local
648687
// mapping for.
649688
#if NETFRAMEWORK
650-
HashSet<string> unmatchedColumns = new();
689+
HashSet<string> unmatchedColumns = new(StringComparer.OrdinalIgnoreCase);
651690
#else
652-
HashSet<string> unmatchedColumns = new(_localColumnMappings.Count);
691+
HashSet<string> unmatchedColumns = new(_localColumnMappings.Count, StringComparer.OrdinalIgnoreCase);
653692
#endif
654693

655694
// Start by assuming all locally mapped Destination columns will be
@@ -659,6 +698,50 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i
659698
unmatchedColumns.Add(_localColumnMappings[i].DestinationColumn);
660699
}
661700

701+
// Apply any necessary column aliases. If an aliased name exists in the
702+
// local column mappings but the canonical name does not, remap that mapping to the canonical name.
703+
Result columnAliasResults = internalResults[ColumnAliasesResultId];
704+
for (int i = 0; i < columnAliasResults.Count; i++)
705+
{
706+
Row aliasRow = columnAliasResults[i];
707+
SqlString canonicalName = (SqlString)aliasRow[ColumnCanonicalNameColumnId];
708+
SqlString aliasedName = (SqlString)aliasRow[ColumnAliasColumnId];
709+
710+
if (canonicalName.IsNull || aliasedName.IsNull)
711+
{
712+
continue;
713+
}
714+
715+
string canonical = canonicalName.Value;
716+
bool canonicalNameExists = unmatchedColumns.Contains(canonical)
717+
// The destination columns might be escaped. If so, search for those instead
718+
|| unmatchedColumns.Contains(SqlServerEscapeHelper.EscapeIdentifier(canonical));
719+
720+
if (canonicalNameExists)
721+
{
722+
continue;
723+
}
724+
725+
// The canonical name does not exist. Look for a local column mapping which matches
726+
// the alias (or its escaped variant) and replace its name with its canonical name.
727+
string alias = aliasedName.Value;
728+
string escapedAlias = SqlServerEscapeHelper.EscapeIdentifier(alias);
729+
730+
for (int j = 0; j < _localColumnMappings.Count; j++)
731+
{
732+
if (unmatchedColumns.Comparer.Equals(_localColumnMappings[j].DestinationColumn, alias)
733+
|| unmatchedColumns.Comparer.Equals(_localColumnMappings[j].DestinationColumn, escapedAlias))
734+
{
735+
unmatchedColumns.Remove(_localColumnMappings[j].DestinationColumn);
736+
737+
unmatchedColumns.Add(canonical);
738+
_localColumnMappings[j].MappedDestinationColumn = canonical;
739+
740+
break;
741+
}
742+
}
743+
}
744+
662745
// Flag to remember whether or not we need to append a comma before
663746
// the next column in the command text.
664747
bool appendComma = false;
@@ -682,7 +765,7 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i
682765
// Are we missing a mapping between the result column and
683766
// this local column (by ordinal or name)?
684767
if (localColumn._destinationColumnOrdinal != metadata.ordinal
685-
&& UnquotedName(localColumn._destinationColumnName) != metadata.column)
768+
&& UnquotedName(localColumn.MappedDestinationColumn) != metadata.column)
686769
{
687770
// Yes, so move on to the next local column.
688771
continue;
@@ -692,8 +775,8 @@ private string AnalyzeTargetAndCreateUpdateBulkCommand(BulkCopySimpleResultSet i
692775
matched = true;
693776

694777
// Remove it from our unmatched set.
695-
unmatchedColumns.Remove(localColumn.DestinationColumn);
696-
778+
unmatchedColumns.Remove(localColumn.MappedDestinationColumn);
779+
697780
// Check for column types that we refuse to bulk load, even
698781
// though we found a match.
699782
//

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlBulkCopyColumnMapping.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ public sealed class SqlBulkCopyColumnMapping
1919
internal int _internalDestinationColumnOrdinal;
2020
internal int _internalSourceColumnOrdinal; // -1 indicates an undetermined value
2121

22+
// Used by SqlBulkCopy to generate the correct column name after mapping alternate names.
23+
internal string MappedDestinationColumn
24+
{
25+
get => field ?? DestinationColumn;
26+
set => field = value;
27+
}
28+
2229
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlBulkCopyColumnMapping.xml' path='docs/members[@name="SqlBulkCopyColumnMapping"]/DestinationColumn/*'/>
2330
public string DestinationColumn
2431
{

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/CopyAllFromReader.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ public static void Test(string srcConstr, string dstConstr, string dstTable)
4040
using (DbDataReader reader = srcCmd.ExecuteReader())
4141
{
4242
IDictionary stats;
43+
long expectedIduCount = DataTestUtility.IsAzureSynapse || DataTestUtility.IsAtLeastSQL2017() ? 1 : 0;
44+
long expectedSelectCount = DataTestUtility.IsAzureSynapse ? 4 : 12;
45+
long expectedSelectRows = DataTestUtility.IsAzureSynapse ? 4 : 14;
46+
long expectedTransactions = DataTestUtility.IsAzureSynapse || DataTestUtility.IsAtLeastSQL2017() ? 1 : 0;
4347
using (SqlBulkCopy bulkcopy = new SqlBulkCopy(dstConn))
4448
{
4549
bulkcopy.DestinationTableName = dstTable;
@@ -60,12 +64,12 @@ public static void Test(string srcConstr, string dstConstr, string dstTable)
6064

6165
DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersReceived"], "Unexpected BuffersReceived value.");
6266
DataTestUtility.AssertEqualsWithDescription((long)3, stats["BuffersSent"], "Unexpected BuffersSent value.");
63-
DataTestUtility.AssertEqualsWithDescription((long)0, stats["IduCount"], "Unexpected IduCount value.");
64-
DataTestUtility.AssertEqualsWithDescription((long)11, stats["SelectCount"], "Unexpected SelectCount value.");
67+
DataTestUtility.AssertEqualsWithDescription(expectedIduCount, stats["IduCount"], "Unexpected IduCount value.");
68+
DataTestUtility.AssertEqualsWithDescription(expectedSelectCount, stats["SelectCount"], "Unexpected SelectCount value.");
6569
DataTestUtility.AssertEqualsWithDescription((long)3, stats["ServerRoundtrips"], "Unexpected ServerRoundtrips value.");
66-
DataTestUtility.AssertEqualsWithDescription((long)14, stats["SelectRows"], "Unexpected SelectRows value.");
70+
DataTestUtility.AssertEqualsWithDescription(expectedSelectRows, stats["SelectRows"], "Unexpected SelectRows value.");
6771
DataTestUtility.AssertEqualsWithDescription((long)2, stats["SumResultSets"], "Unexpected SumResultSets value.");
68-
DataTestUtility.AssertEqualsWithDescription((long)0, stats["Transactions"], "Unexpected Transactions value.");
72+
DataTestUtility.AssertEqualsWithDescription(expectedTransactions, stats["Transactions"], "Unexpected Transactions value.");
6973
}
7074
}
7175
}

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/SqlGraphTables.cs

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,138 @@ public void WriteToServer_CopyToSqlGraphNodeTable_Succeeds()
3535
nodeCopy.ColumnMappings.Add("Name", "Name");
3636
nodeCopy.WriteToServer(nodes);
3737
}
38+
39+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.IsAtLeastSQL2017))]
40+
public void WriteToServer_CopyToAliasedColumnName_Succeeds()
41+
{
42+
string connectionString = DataTestUtility.TCPConnectionString;
43+
44+
using SqlConnection dstConn = new SqlConnection(connectionString);
45+
using DataTable edges = new DataTable()
46+
{
47+
Columns = { new DataColumn("To_ID", typeof(string)), new DataColumn("From_ID", typeof(string)), new DataColumn("Description", typeof(string)) }
48+
};
49+
50+
dstConn.Open();
51+
52+
using Table srcNodeTable = new(dstConn, "SqlGraph_NodeByAlias", "(Id INT PRIMARY KEY IDENTITY(1,1), [Name] VARCHAR(100)) AS NODE");
53+
using Table dstEdgeTable = new(dstConn, "SqlGraph_EdgeByAlias", "([Description] VARCHAR(100)) AS EDGE");
54+
55+
string sampleNodeDataCommand = @$"INSERT INTO {srcNodeTable.Name} ([Name]) SELECT LEFT([name], 100) FROM sys.sysobjects";
56+
using (SqlCommand insertSampleNodes = new(sampleNodeDataCommand, dstConn))
57+
{
58+
insertSampleNodes.ExecuteNonQuery();
59+
}
60+
61+
using (SqlCommand nodeQuery = new SqlCommand($"SELECT $node_id FROM {srcNodeTable.Name}", dstConn))
62+
using (SqlDataReader reader = nodeQuery.ExecuteReader())
63+
{
64+
bool firstRead = reader.Read();
65+
string toId;
66+
string fromId;
67+
68+
Assert.True(firstRead);
69+
toId = reader.GetString(0);
70+
71+
while (reader.Read())
72+
{
73+
fromId = reader.GetString(0);
74+
75+
edges.Rows.Add(toId, fromId, "Test Description");
76+
toId = fromId;
77+
}
78+
}
79+
80+
using (SqlBulkCopy edgeCopy = new(dstConn))
81+
{
82+
edgeCopy.DestinationTableName = dstEdgeTable.Name;
83+
edgeCopy.ColumnMappings.Add("To_ID", "$to_id");
84+
edgeCopy.ColumnMappings.Add("From_ID", "$from_id");
85+
edgeCopy.ColumnMappings.Add("Description", "Description");
86+
87+
edgeCopy.WriteToServer(edges);
88+
}
89+
90+
// Read the values back, comparing to the source DataTable
91+
using SqlCommand dstVerificationCommand = new($"SELECT $to_id, $from_id, [Description] FROM {dstEdgeTable.Name} ORDER BY $to_id ASC", dstConn);
92+
using SqlDataReader dstVerificationReader = dstVerificationCommand.ExecuteReader();
93+
int currentRow = 0;
94+
DataRow[] sortedRows = edges.Select(filterExpression: null, sort: "To_ID ASC");
95+
96+
while (dstVerificationReader.Read())
97+
{
98+
string toId = dstVerificationReader.GetString(0);
99+
string fromId = dstVerificationReader.GetString(1);
100+
string description = dstVerificationReader.GetString(2);
101+
DataRow currSourceRow = sortedRows[currentRow];
102+
103+
Assert.Equal(currSourceRow["To_ID"], toId);
104+
Assert.Equal(currSourceRow["From_ID"], fromId);
105+
Assert.Equal(currSourceRow["Description"], description);
106+
107+
currentRow++;
108+
}
109+
}
110+
111+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureSynapse), nameof(DataTestUtility.IsAtLeastSQL2017))]
112+
public void WriteToServer_CopyToTableWithSameNameAsColumnAlias_Succeeds()
113+
{
114+
string connectionString = DataTestUtility.TCPConnectionString;
115+
116+
using SqlConnection dstConn = new SqlConnection(connectionString);
117+
using DataTable nodes = new DataTable()
118+
{
119+
Columns = { new DataColumn("Name", typeof(string)) }
120+
};
121+
122+
dstConn.Open();
123+
124+
for (int i = 0; i < 5; i++)
125+
{
126+
nodes.Rows.Add($"Name {i}");
127+
}
128+
129+
using Table dstGraphTable = new(dstConn, "SqlGraph_NodeWithAlias", "(Id INT PRIMARY KEY IDENTITY(1,1), [Name] VARCHAR(100), [$node_id] VARCHAR(100)) AS NODE");
130+
using Table dstNormalTable = new(dstConn, "NonGraph_NodeWithAlias", "(Id INT PRIMARY KEY IDENTITY(1,1), [Name] VARCHAR(100), [$node_id] VARCHAR(100))");
131+
132+
using (SqlBulkCopy nodeCopy = new SqlBulkCopy(dstConn))
133+
{
134+
nodeCopy.DestinationTableName = dstGraphTable.Name;
135+
nodeCopy.ColumnMappings.Add("Name", "Name");
136+
nodeCopy.ColumnMappings.Add("Name", "$node_id");
137+
nodeCopy.WriteToServer(nodes);
138+
139+
nodeCopy.DestinationTableName = dstNormalTable.Name;
140+
nodeCopy.WriteToServer(nodes);
141+
}
142+
143+
// Read the values back, ensuring that we haven't overwritten the $node_id alias with the contents of the [$node_id] column.
144+
// SELECTing $node_id will read the SQL Graph's node ID, SELECTing [$node_id] will read the column named $node_id.
145+
using (SqlCommand graphVerificationCommand = new SqlCommand($"SELECT Id, $node_id, [$node_id], Name FROM {dstGraphTable.Name}", dstConn))
146+
using (SqlDataReader reader = graphVerificationCommand.ExecuteReader())
147+
{
148+
while (reader.Read())
149+
{
150+
string aliasNodeId = reader.GetString(1);
151+
string physicalNodeId = reader.GetString(2);
152+
string name = reader.GetString(3);
153+
154+
Assert.NotEqual(physicalNodeId, aliasNodeId);
155+
Assert.Equal(name, physicalNodeId);
156+
}
157+
}
158+
159+
using (SqlCommand normalVerificationCommand = new SqlCommand($"SELECT [$node_id], Name FROM {dstNormalTable.Name}", dstConn))
160+
using (SqlDataReader reader = normalVerificationCommand.ExecuteReader())
161+
{
162+
while (reader.Read())
163+
{
164+
string physicalNodeId = reader.GetString(0);
165+
string name = reader.GetString(1);
166+
167+
Assert.Equal(name, physicalNodeId);
168+
}
169+
}
170+
}
38171
}
39172
}

0 commit comments

Comments
 (0)