Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
package com.google.cloud.spark.bigquery.integration;

import static com.google.common.truth.Truth.assertThat;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.concat;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.row_number;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assume.assumeThat;
import static org.junit.Assume.assumeTrue;
Expand All @@ -42,6 +46,8 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
Expand Down Expand Up @@ -316,6 +322,34 @@ public void testTimestampNTZReadFromBigQuery() {
assertThat(row.get(0)).isEqualTo(dateTime);
}

@Test
public void testWindowFunctionPartitionBy() {
String tableName = "bigquery-public-data.ga4_obfuscated_sample_ecommerce.events_20210131";
WindowSpec windowSpec =
Window.partitionBy(concat(col("user_pseudo_id"), col("event_timestamp"), col("event_name")))
.orderBy(lit("window_ordering"));

Dataset<Row> df =
spark
.read()
.format("bigquery")
.option("table", tableName)
.option("readDataFormat", dataFormat)
.load()
.withColumn("row_num", row_number().over(windowSpec));

Dataset<Row> selectedDF =
df.select("user_pseudo_id", "event_name", "event_timestamp", "row_num");

assertThat(selectedDF.columns().length).isEqualTo(4);
assertThat(
Arrays.stream(df.schema().fields())
.filter(field -> field.name().equals("row_num"))
.count())
.isEqualTo(1);
assertThat(selectedDF.head().get(3)).isEqualTo(1);
}

static <K, V> Map<K, V> scalaMapToJavaMap(scala.collection.Map<K, V> map) {
ImmutableMap.Builder<K, V> result = ImmutableMap.<K, V>builder();
map.foreach(entry -> result.put(entry._1(), entry._2()));
Expand Down
Loading