@@ -328,6 +328,11 @@ Optional<StructField> convertMap(Field field, Metadata metadata) {
}
Field key = subFields.get("key");
Field value = subFields.get("value");
+    // If the BigQuery 'key' field is NULLABLE, this cannot be safely converted
+    // to a Spark Map. It should remain an Array of Structs.
+    if (key.getMode() != Field.Mode.REQUIRED) {
+      return Optional.empty();
+    }
MapType mapType = DataTypes.createMapType(convert(key).dataType(), convert(value).dataType());
// The returned type is not nullable because the original field is REPEATED, not NULLABLE.
// There are some compromises we need to make, as BigQuery has no native MAP type.
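
A minimal sketch of the resulting behavior, using the same com.google.cloud.bigquery.Field API as above (the field name and variables are illustrative, and map-type conversion is assumed to be enabled):

Field nullableKeyRecord =
    Field.newBuilder(
            "tags",
            LegacySQLTypeName.RECORD,
            Field.of("key", LegacySQLTypeName.STRING), // mode unset, treated as NULLABLE
            Field.of("value", LegacySQLTypeName.INTEGER))
        .setMode(Field.Mode.REPEATED)
        .build();
// convertMap(...) now returns Optional.empty() for this field, so it is read as
// ArrayType(StructType(key, value)): Spark map keys may never be null.

Field requiredKeyRecord =
    Field.newBuilder(
            "tags",
            LegacySQLTypeName.RECORD,
            Field.newBuilder("key", LegacySQLTypeName.STRING)
                .setMode(Field.Mode.REQUIRED)
                .build(),
            Field.of("value", LegacySQLTypeName.INTEGER))
        .setMode(Field.Mode.REPEATED)
        .build();
// Only this shape remains eligible for conversion to a Spark MapType.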
@@ -197,7 +197,7 @@ private Field getKeyValueRepeatedField() {
return Field.newBuilder(
"foo",
LegacySQLTypeName.RECORD,
Field.of("key", LegacySQLTypeName.INTEGER),
Field.newBuilder("key", LegacySQLTypeName.INTEGER).setMode(Mode.REQUIRED).build(),
Field.of("value", LegacySQLTypeName.STRING))
.setMode(Mode.REPEATED)
.build();
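
For context: Field.of(...) does not set a mode (getMode() returns null, which BigQuery treats as NULLABLE), so the fixture has to mark the key explicitly as REQUIRED for the new guard to accept it. A minimal illustration, with hypothetical variable names:

Field implicitKey = Field.of("key", LegacySQLTypeName.INTEGER);
// implicitKey.getMode() == null, so key.getMode() != Field.Mode.REQUIRED is true
// and the record stays an array of structs.

Field requiredKey =
    Field.newBuilder("key", LegacySQLTypeName.INTEGER)
        .setMode(Field.Mode.REQUIRED)
        .build();
// requiredKey.getMode() == Field.Mode.REQUIRED, so map conversion can apply.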
@@ -467,7 +467,7 @@ public void testConvertBigQueryToSparkArray_mapTypeConversionDisabled() {
.convert(getKeyValueRepeatedField());
StructType elementType =
new StructType()
.add("key", DataTypes.LongType, true)
.add("key", DataTypes.LongType, false)
.add("value", DataTypes.StringType, true);
ArrayType arrayType = new ArrayType(elementType, true);
assertThat(field.dataType()).isEqualTo(arrayType);
@@ -15,7 +15,12 @@
*/
package com.google.cloud.spark.bigquery.integration;

+import static com.google.cloud.spark.bigquery.integration.TestConstants.GA4_TABLE;
import static com.google.common.truth.Truth.assertThat;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.concat;
+import static org.apache.spark.sql.functions.lit;
+import static org.apache.spark.sql.functions.row_number;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assume.assumeThat;
import static org.junit.Assume.assumeTrue;
@@ -42,6 +47,8 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.Window;
+import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
@@ -261,7 +268,9 @@ public void testConvertBigQueryMapToSparkMap() throws Exception {
"map_field",
LegacySQLTypeName.RECORD,
FieldList.of(
Field.of("key", LegacySQLTypeName.STRING),
Field.newBuilder("key", LegacySQLTypeName.STRING)
.setMode(Field.Mode.REQUIRED)
.build(),
Field.of("value", LegacySQLTypeName.INTEGER)))
.setMode(Field.Mode.REPEATED)
.build())))
@@ -316,6 +325,61 @@ public void testTimestampNTZReadFromBigQuery() {
assertThat(row.get(0)).isEqualTo(dateTime);
}

+  @Test
+  public void testWindowFunctionPartitionBy() {
+    WindowSpec windowSpec =
+        Window.partitionBy("user_pseudo_id", "event_timestamp", "event_name")
+            .orderBy("event_bundle_sequence_id");
+
+    Dataset<Row> df =
+        spark
+            .read()
+            .format("bigquery")
+            .option("table", GA4_TABLE)
+            .option("readDataFormat", dataFormat)
+            .load()
+            .withColumn("row_num", row_number().over(windowSpec));
+
+    Dataset<Row> selectedDF =
+        df.select("user_pseudo_id", "event_name", "event_timestamp", "row_num");
+
+    assertThat(selectedDF.columns().length).isEqualTo(4);
+    assertThat(
+            Arrays.stream(df.schema().fields())
+                .filter(field -> field.name().equals("row_num"))
+                .count())
+        .isEqualTo(1);
+    assertThat(selectedDF.head().get(3)).isEqualTo(1);
+  }
+
+  @Test
+  public void testWindowFunctionPartitionByWithArray() {
+    assumeTrue("This test only works for AVRO dataformat", dataFormat.equals("AVRO"));
+    WindowSpec windowSpec =
+        Window.partitionBy(concat(col("user_pseudo_id"), col("event_timestamp"), col("event_name")))
+            .orderBy(lit("window_ordering"));
+
+    Dataset<Row> df =
+        spark
+            .read()
+            .format("bigquery")
+            .option("table", GA4_TABLE)
+            .option("readDataFormat", dataFormat)
+            .load()
+            .withColumn("row_num", row_number().over(windowSpec));
+
+    Dataset<Row> selectedDF =
+        df.select("user_pseudo_id", "event_name", "event_timestamp", "event_params", "row_num");
+
+    assertThat(selectedDF.columns().length).isEqualTo(5);
+    assertThat(
+            Arrays.stream(df.schema().fields())
+                .filter(field -> field.name().equals("row_num"))
+                .count())
+        .isEqualTo(1);
+    assertThat(selectedDF.head().get(4)).isEqualTo(1);
+  }
+
static <K, V> Map<K, V> scalaMapToJavaMap(scala.collection.Map<K, V> map) {
ImmutableMap.Builder<K, V> result = ImmutableMap.<K, V>builder();
map.foreach(entry -> result.put(entry._1(), entry._2()));
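
A usage note on the tests above: row_number() requires an ordered window, which is why the second test orders by a constant lit(...) when no meaningful ordering column exists; the numbering within each partition is then arbitrary but still valid. A minimal sketch of the same pattern over an already-loaded DataFrame df (column names are illustrative):

WindowSpec byUser =
    Window.partitionBy(col("user_pseudo_id")).orderBy(col("event_timestamp"));
Dataset<Row> numbered = df.withColumn("rn", row_number().over(byUser));
// rn restarts at 1 within each user_pseudo_id partition, so keeping
// rn == 1 selects each user's earliest event.
Dataset<Row> firstEvents = numbered.filter(col("rn").equalTo(1));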
@@ -133,6 +133,8 @@ public class TestConstants {
static DataType BQ_NUMERIC = DataTypes.createDecimalType(38, 9);
static DataType BQ_BIGNUMERIC = DataTypes.createDecimalType(38, 38);
public static int BIG_NUMERIC_COLUMN_POSITION = 11;
+  public static final String GA4_TABLE =
+      "bigquery-public-data.ga4_obfuscated_sample_ecommerce.events_20210131";

public static StructType ALL_TYPES_TABLE_SCHEMA =
new StructType(
@@ -15,12 +15,22 @@
*/
package com.google.cloud.spark.bigquery.integration;

+import java.util.Arrays;
+import java.util.Collection;
import org.apache.spark.sql.types.DataTypes;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;

+@RunWith(Parameterized.class)
public class Spark35ReadByFormatIntegrationTest extends ReadByFormatIntegrationTestBase {

-  public Spark35ReadByFormatIntegrationTest() {
-    super("ARROW", /* userProvidedSchemaAllowed */ false, DataTypes.TimestampNTZType);
+  @Parameterized.Parameters
+  public static Collection<String> data() {
+    return Arrays.asList("AVRO", "ARROW");
+  }
+
+  public Spark35ReadByFormatIntegrationTest(String dataFormat) {
+    super(dataFormat, /* userProvidedSchemaAllowed */ false, DataTypes.TimestampNTZType);
}

// tests are from the super-class
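
For reference: with JUnit's Parameterized runner, the class is instantiated once per value returned by data(), so the inherited read-by-format tests now run against both AVRO and ARROW rather than ARROW only.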