Commit 9635fb4

Spark: Use delimited column names in CreateChangelogViewProcedure (#12418)

Parent: c879eed

4 files changed: +143 −4 lines changed

spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateChangelogViewProcedure.java (+45)

@@ -20,13 +20,16 @@
 
 import static org.junit.Assert.assertThrows;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.stream.Collectors;
 import org.apache.iceberg.ChangelogOperation;
 import org.apache.iceberg.Snapshot;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.spark.sql.types.StructField;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Test;
@@ -95,6 +98,48 @@ public void testCustomizedViewName() {
     Assert.assertEquals(2, rowCount);
   }
 
+  @Test
+  public void testNonStandardColumnNames() {
+    sql("CREATE TABLE %s (`the id` INT, `the.data` STRING) USING iceberg", tableName);
+    sql("ALTER TABLE %s ADD PARTITION FIELD `the.data`", tableName);
+
+    sql("INSERT INTO %s VALUES (1, 'a')", tableName);
+    sql("INSERT INTO %s VALUES (2, 'b')", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    Snapshot snap1 = table.currentSnapshot();
+
+    sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName);
+
+    table.refresh();
+
+    Snapshot snap2 = table.currentSnapshot();
+
+    sql(
+        "CALL %s.system.create_changelog_view("
+            + "table => '%s',"
+            + "options => map('%s','%s','%s','%s'),"
+            + "changelog_view => '%s')",
+        catalogName,
+        tableName,
+        SparkReadOptions.START_SNAPSHOT_ID,
+        snap1.snapshotId(),
+        SparkReadOptions.END_SNAPSHOT_ID,
+        snap2.snapshotId(),
+        "cdc_view");
+
+    var df = spark.sql("select * from cdc_view");
+    var fieldNames =
+        Arrays.stream(df.schema().fields()).map(StructField::name).collect(Collectors.toList());
+
+    Assert.assertEquals(
+        "Result Schema should match",
+        List.of("the id", "the.data", "_change_type", "_change_ordinal", "_commit_snapshot_id"),
+        fieldNames);
+    Assert.assertEquals("Result Row Count should match", 2, df.collectAsList().size());
+  }
+
   @Test
   public void testNoSnapshotIdInput() {
     createTableWithTwoColumns();
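For intuition about the final row-count assertion: the INSERT OVERWRITE replaces the data in partition 'b', so the changelog between the two snapshots should carry exactly two entries, the deleted old row and the inserted new row. Below is a hedged sketch of inspecting them, written as if appended inside the test above; the uppercase INSERT/DELETE change-type labels are an assumption about Iceberg's changelog output, not something this test asserts, and note that the non-standard names must be backtick-delimited in user SQL as well:

```java
// Hypothetical continuation of testNonStandardColumnNames(); `spark` is the
// test's SparkSession and cdc_view is the view created by the CALL above.
var changes =
    spark.sql("SELECT `the id`, `the.data`, _change_type FROM cdc_view ORDER BY `the id`");
changes.show();
// Expected, assuming Iceberg's usual INSERT/DELETE change-type labels:
// | the id | the.data | _change_type |
// |     -2 |        b |       INSERT |
// |      2 |        b |       DELETE |
```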

spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java (+27 −2)

@@ -197,7 +197,12 @@ private Dataset<Row> removeCarryoverRows(Dataset<Row> df, boolean netChanges) {
     }
 
     Column[] repartitionSpec =
-        Arrays.stream(df.columns()).filter(columnsToKeep).map(df::col).toArray(Column[]::new);
+        Arrays.stream(df.columns())
+            .filter(columnsToKeep)
+            .map(CreateChangelogViewProcedure::delimitedName)
+            .map(df::col)
+            .toArray(Column[]::new);
+
     return applyCarryoverRemoveIterator(df, repartitionSpec, netChanges);
   }
 
@@ -206,7 +211,9 @@ private String[] identifierColumns(ProcedureInput input, Identifier tableIdent)
       return input.asStringArray(IDENTIFIER_COLUMNS_PARAM);
     } else {
       Table table = loadSparkTable(tableIdent).table();
-      return table.schema().identifierFieldNames().toArray(new String[0]);
+      return table.schema().identifierFieldNames().stream()
+          .map(CreateChangelogViewProcedure::delimitedName)
+          .toArray(String[]::new);
     }
   }
 
@@ -257,6 +264,24 @@ private Dataset<Row> applyCarryoverRemoveIterator(
         RowEncoder.apply(schema));
   }
 
+  /**
+   * Ensure that column can be referenced using this name. Issues may come from field names that
+   * contain non-standard characters. In Spark, this can be fixed by using <a
+   * href="https://spark.apache.org/docs/3.5.0/sql-ref-identifier.html#delimited-identifier">backtick
+   * quotes</a>.
+   *
+   * @param columnName Column name that potentially can contain non-standard characters.
+   * @return A name that can be safely used within Spark to reference a column by its name.
+   */
+  private static String delimitedName(String columnName) {
+    var delimited = columnName.startsWith("`") && columnName.endsWith("`");
+    if (delimited) {
+      return columnName;
+    } else {
+      return "`" + columnName.replaceAll("`", "``") + "`";
+    }
+  }
+
   private static Column[] sortSpec(Dataset<Row> df, Column[] repartitionSpec, boolean netChanges) {
     Column changeType = df.col(MetadataColumns.CHANGE_TYPE.name());
     Column changeOrdinal = df.col(MetadataColumns.CHANGE_ORDINAL.name());
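The escaping rule in delimitedName follows Spark's delimited-identifier syntax: wrap the name in backticks and double any backtick it already contains, leaving names that already look delimited untouched. Here is a minimal standalone sketch of the same rule; the class and method names are illustrative, not part of the commit:

```java
// Illustrative copy of the backtick-delimiting rule from delimitedName().
public class DelimitedNameSketch {

  // Wrap a name in backticks, doubling embedded backticks, unless the
  // name already starts and ends with a backtick.
  static String delimited(String name) {
    if (name.startsWith("`") && name.endsWith("`")) {
      return name;
    }
    // replace() does a literal replacement; for a plain backtick this is
    // equivalent to the replaceAll() call in the commit.
    return "`" + name.replace("`", "``") + "`";
  }

  public static void main(String[] args) {
    System.out.println(delimited("the id"));    // `the id`
    System.out.println(delimited("the.data"));  // `the.data`
    System.out.println(delimited("a`b"));       // `a``b`
    System.out.println(delimited("`quoted`"));  // `quoted` (returned as is)
  }
}
```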

spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCreateChangelogViewProcedure.java (+44)

@@ -21,13 +21,16 @@
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
+import java.util.Arrays;
 import java.util.List;
+import java.util.stream.Collectors;
 import org.apache.iceberg.ChangelogOperation;
 import org.apache.iceberg.ParameterizedTestExtension;
 import org.apache.iceberg.Snapshot;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.spark.sql.types.StructField;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.TestTemplate;
 import org.junit.jupiter.api.extension.ExtendWith;
@@ -92,6 +95,47 @@ public void testCustomizedViewName() {
     assertThat(rowCount).isEqualTo(2);
   }
 
+  @TestTemplate
+  public void testNonStandardColumnNames() {
+    sql("CREATE TABLE %s (`the id` INT, `the.data` STRING) USING iceberg", tableName);
+    sql("ALTER TABLE %s ADD PARTITION FIELD `the.data`", tableName);
+
+    sql("INSERT INTO %s VALUES (1, 'a')", tableName);
+    sql("INSERT INTO %s VALUES (2, 'b')", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    Snapshot snap1 = table.currentSnapshot();
+
+    sql("INSERT OVERWRITE %s VALUES (-2, 'b')", tableName);
+
+    table.refresh();
+
+    Snapshot snap2 = table.currentSnapshot();
+
+    sql(
+        "CALL %s.system.create_changelog_view("
+            + "table => '%s',"
+            + "options => map('%s','%s','%s','%s'),"
+            + "changelog_view => '%s')",
+        catalogName,
+        tableName,
+        SparkReadOptions.START_SNAPSHOT_ID,
+        snap1.snapshotId(),
+        SparkReadOptions.END_SNAPSHOT_ID,
+        snap2.snapshotId(),
+        "cdc_view");
+
+    var df = spark.sql("select * from cdc_view");
+    var fieldNames =
+        Arrays.stream(df.schema().fields()).map(StructField::name).collect(Collectors.toList());
+    assertThat(fieldNames)
+        .containsExactly(
+            "the id", "the.data", "_change_type", "_change_ordinal", "_commit_snapshot_id");
+
+    assertThat(df.collectAsList()).hasSize(2);
+  }
+
   @TestTemplate
   public void testNoSnapshotIdInput() {
     createTableWithTwoColumns();
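Expanded with concrete values, the formatted CALL in the tests corresponds to a single SQL statement of the shape below; the catalog name, table name, and snapshot IDs are placeholders, and the option keys are the string values behind SparkReadOptions.START_SNAPSHOT_ID and END_SNAPSHOT_ID:

```java
// Hypothetical expanded form of the CALL issued by the tests above;
// names and snapshot IDs are placeholders.
spark.sql(
    "CALL spark_catalog.system.create_changelog_view("
        + "table => 'db.tbl',"
        + "options => map('start-snapshot-id','1234','end-snapshot-id','5678'),"
        + "changelog_view => 'cdc_view')");
```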

spark/v3.5/spark/src/main/java/org/apache/iceberg/spark/procedures/CreateChangelogViewProcedure.java (+27 −2)

@@ -210,7 +210,12 @@ private Dataset<Row> removeCarryoverRows(Dataset<Row> df, boolean netChanges) {
     }
 
     Column[] repartitionSpec =
-        Arrays.stream(df.columns()).filter(columnsToKeep).map(df::col).toArray(Column[]::new);
+        Arrays.stream(df.columns())
+            .filter(columnsToKeep)
+            .map(CreateChangelogViewProcedure::delimitedName)
+            .map(df::col)
+            .toArray(Column[]::new);
+
     return applyCarryoverRemoveIterator(df, repartitionSpec, netChanges);
   }
 
@@ -219,7 +224,9 @@ private String[] identifierColumns(ProcedureInput input, Identifier tableIdent)
       return input.asStringArray(IDENTIFIER_COLUMNS_PARAM);
     } else {
       Table table = loadSparkTable(tableIdent).table();
-      return table.schema().identifierFieldNames().toArray(new String[0]);
+      return table.schema().identifierFieldNames().stream()
+          .map(CreateChangelogViewProcedure::delimitedName)
+          .toArray(String[]::new);
     }
   }
 
@@ -270,6 +277,24 @@ private Dataset<Row> applyCarryoverRemoveIterator(
         Encoders.row(schema));
   }
 
+  /**
+   * Ensure that column can be referenced using this name. Issues may come from field names that
+   * contain non-standard characters. In Spark, this can be fixed by using <a
+   * href="https://spark.apache.org/docs/3.5.0/sql-ref-identifier.html#delimited-identifier">backtick
+   * quotes</a>.
+   *
+   * @param columnName Column name that potentially can contain non-standard characters.
+   * @return A name that can be safely used within Spark to reference a column by its name.
+   */
+  private static String delimitedName(String columnName) {
+    var delimited = columnName.startsWith("`") && columnName.endsWith("`");
+    if (delimited) {
+      return columnName;
+    } else {
+      return "`" + columnName.replaceAll("`", "``") + "`";
+    }
+  }
+
   private static Column[] sortSpec(Dataset<Row> df, Column[] repartitionSpec, boolean netChanges) {
     Column changeType = df.col(MetadataColumns.CHANGE_TYPE.name());
    Column changeOrdinal = df.col(MetadataColumns.CHANGE_ORDINAL.name());
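The delimiting matters because Dataset.col parses its argument: an undelimited dotted name such as the.data is interpreted as field access on a struct column named the, so a top-level column literally named the.data fails to resolve, which is presumably what the undelimited df::col mapping hit before this change. A small self-contained sketch of the difference follows; the local session setup and the literal schema are just for demonstration:

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Demonstrates why column names are backtick-delimited before df.col().
public class DottedColumnDemo {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().master("local[1]").appName("dotted-columns").getOrCreate();

    // A column whose literal name contains a dot.
    Dataset<Row> df = spark.sql("SELECT 1 AS `the id`, 'a' AS `the.data`");

    // Delimited: resolves the column literally named "the.data".
    df.select(df.col("`the.data`")).show();

    // Undelimited: "the.data" would be parsed as struct-field access
    // (column `the`, field `data`) and fail with an AnalysisException:
    // df.col("the.data");

    spark.stop();
  }
}
```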
