Skip to content

Commit 6bd76a5

Browse files
authored
Add test window function to ReadByFormatIntegrationTestBase (#1447)
* Add testWindowFunctionPartitionBy in ReadByFormatIntegrationTestBase for public GA4 dataset
* Change format to AVRO to bypass the exception.
* Change format to AVRO to bypass the exception.
* Fix the schema converter for array type.
* Revert the local changes
* Format and fix tests
* Refactor according to gemini-code-assist.
* Add Parameterized tests for ARROW and AVRO format in Spark35ReadByFormatIntegrationTest.
* Rename test for AVRO format.
1 parent bb9f7c1 commit 6bd76a5

File tree

5 files changed

+86
-5
lines changed

5 files changed

+86
-5
lines changed

spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/SchemaConverters.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,11 @@ Optional<StructField> convertMap(Field field, Metadata metadata) {
328328
}
329329
Field key = subFields.get("key");
330330
Field value = subFields.get("value");
331+
// If the BigQuery 'key' field is NULLABLE, this cannot be safely converted
332+
// to a Spark Map. It should remain an Array of Structs.
333+
if (key.getMode() != Field.Mode.REQUIRED) {
334+
return Optional.empty();
335+
}
331336
MapType mapType = DataTypes.createMapType(convert(key).dataType(), convert(value).dataType());
332337
// The returned type is not nullable because the original field is a REPEATED, not NULLABLE.
333338
// There are some compromises we need to do as BigQuery has no native MAP type

spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/SchemaConverterTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ private Field getKeyValueRepeatedField() {
197197
return Field.newBuilder(
198198
"foo",
199199
LegacySQLTypeName.RECORD,
200-
Field.of("key", LegacySQLTypeName.INTEGER),
200+
Field.newBuilder("key", LegacySQLTypeName.INTEGER).setMode(Mode.REQUIRED).build(),
201201
Field.of("value", LegacySQLTypeName.STRING))
202202
.setMode(Mode.REPEATED)
203203
.build();
@@ -467,7 +467,7 @@ public void testConvertBigQueryToSparkArray_mapTypeConversionDisabled() {
467467
.convert(getKeyValueRepeatedField());
468468
StructType elementType =
469469
new StructType()
470-
.add("key", DataTypes.LongType, true)
470+
.add("key", DataTypes.LongType, false)
471471
.add("value", DataTypes.StringType, true);
472472
ArrayType arrayType = new ArrayType(elementType, true);
473473
assertThat(field.dataType()).isEqualTo(arrayType);

spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/ReadByFormatIntegrationTestBase.java

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@
1515
*/
1616
package com.google.cloud.spark.bigquery.integration;
1717

18+
import static com.google.cloud.spark.bigquery.integration.TestConstants.GA4_TABLE;
1819
import static com.google.common.truth.Truth.assertThat;
20+
import static org.apache.spark.sql.functions.col;
21+
import static org.apache.spark.sql.functions.concat;
22+
import static org.apache.spark.sql.functions.lit;
23+
import static org.apache.spark.sql.functions.row_number;
1924
import static org.hamcrest.CoreMatchers.is;
2025
import static org.junit.Assume.assumeThat;
2126
import static org.junit.Assume.assumeTrue;
@@ -42,6 +47,8 @@
4247
import org.apache.spark.sql.Dataset;
4348
import org.apache.spark.sql.Encoders;
4449
import org.apache.spark.sql.Row;
50+
import org.apache.spark.sql.expressions.Window;
51+
import org.apache.spark.sql.expressions.WindowSpec;
4552
import org.apache.spark.sql.types.DataType;
4653
import org.apache.spark.sql.types.DataTypes;
4754
import org.apache.spark.sql.types.StructField;
@@ -261,7 +268,9 @@ public void testConvertBigQueryMapToSparkMap() throws Exception {
261268
"map_field",
262269
LegacySQLTypeName.RECORD,
263270
FieldList.of(
264-
Field.of("key", LegacySQLTypeName.STRING),
271+
Field.newBuilder("key", LegacySQLTypeName.STRING)
272+
.setMode(Field.Mode.REQUIRED)
273+
.build(),
265274
Field.of("value", LegacySQLTypeName.INTEGER)))
266275
.setMode(Field.Mode.REPEATED)
267276
.build())))
@@ -316,6 +325,61 @@ public void testTimestampNTZReadFromBigQuery() {
316325
assertThat(row.get(0)).isEqualTo(dateTime);
317326
}
318327

328+
@Test
329+
public void testWindowFunctionPartitionBy() {
330+
WindowSpec windowSpec =
331+
Window.partitionBy("user_pseudo_id", "event_timestamp", "event_name")
332+
.orderBy("event_bundle_sequence_id");
333+
334+
Dataset<Row> df =
335+
spark
336+
.read()
337+
.format("bigquery")
338+
.option("table", GA4_TABLE)
339+
.option("readDataFormat", dataFormat)
340+
.load()
341+
.withColumn("row_num", row_number().over(windowSpec));
342+
343+
Dataset<Row> selectedDF =
344+
df.select("user_pseudo_id", "event_name", "event_timestamp", "row_num");
345+
346+
assertThat(selectedDF.columns().length).isEqualTo(4);
347+
assertThat(
348+
Arrays.stream(df.schema().fields())
349+
.filter(field -> field.name().equals("row_num"))
350+
.count())
351+
.isEqualTo(1);
352+
assertThat(selectedDF.head().get(3)).isEqualTo(1);
353+
}
354+
355+
@Test
356+
public void testWindowFunctionPartitionByWithArray() {
357+
assumeTrue("This test only works for AVRO dataformat", dataFormat.equals("AVRO"));
358+
WindowSpec windowSpec =
359+
Window.partitionBy(concat(col("user_pseudo_id"), col("event_timestamp"), col("event_name")))
360+
.orderBy(lit("window_ordering"));
361+
362+
Dataset<Row> df =
363+
spark
364+
.read()
365+
.format("bigquery")
366+
.option("table", GA4_TABLE)
367+
.option("readDataFormat", dataFormat)
368+
.load()
369+
.withColumn("row_num", row_number().over(windowSpec));
370+
371+
Dataset<Row> selectedDF =
372+
df.select("user_pseudo_id", "event_name", "event_timestamp", "event_params", "row_num");
373+
374+
assertThat(selectedDF.columns().length).isEqualTo(5);
375+
assertThat(
376+
Arrays.stream(df.schema().fields())
377+
.filter(field -> field.name().equals("row_num"))
378+
.count())
379+
.isEqualTo(1);
380+
assertThat(selectedDF.head().get(4)).isEqualTo(1);
381+
}
382+
319383
static <K, V> Map<K, V> scalaMapToJavaMap(scala.collection.Map<K, V> map) {
320384
ImmutableMap.Builder<K, V> result = ImmutableMap.<K, V>builder();
321385
map.foreach(entry -> result.put(entry._1(), entry._2()));

spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/integration/TestConstants.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ public class TestConstants {
133133
static DataType BQ_NUMERIC = DataTypes.createDecimalType(38, 9);
134134
static DataType BQ_BIGNUMERIC = DataTypes.createDecimalType(38, 38);
135135
public static int BIG_NUMERIC_COLUMN_POSITION = 11;
136+
public static final String GA4_TABLE =
137+
"bigquery-public-data.ga4_obfuscated_sample_ecommerce.events_20210131";
136138

137139
public static StructType ALL_TYPES_TABLE_SCHEMA =
138140
new StructType(

spark-bigquery-dsv2/spark-3.5-bigquery/src/test/java/com/google/cloud/spark/bigquery/integration/Spark35ReadByFormatIntegrationTest.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,22 @@
1515
*/
1616
package com.google.cloud.spark.bigquery.integration;
1717

18+
import java.util.Arrays;
19+
import java.util.Collection;
1820
import org.apache.spark.sql.types.DataTypes;
21+
import org.junit.runner.RunWith;
22+
import org.junit.runners.Parameterized;
1923

24+
@RunWith(Parameterized.class)
2025
public class Spark35ReadByFormatIntegrationTest extends ReadByFormatIntegrationTestBase {
2126

22-
public Spark35ReadByFormatIntegrationTest() {
23-
super("ARROW", /* userProvidedSchemaAllowed */ false, DataTypes.TimestampNTZType);
27+
@Parameterized.Parameters
28+
public static Collection<String> data() {
29+
return Arrays.asList("AVRO", "ARROW");
30+
}
31+
32+
public Spark35ReadByFormatIntegrationTest(String dataFormat) {
33+
super(dataFormat, /* userProvidedSchemaAllowed */ false, DataTypes.TimestampNTZType);
2434
}
2535

2636
// tests are from the super-class

0 commit comments

Comments (0)