diff --git a/spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/metrics/SparkBigQueryReadSessionMetrics.java b/spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/metrics/SparkBigQueryReadSessionMetrics.java index 1a2c306a2b..e7af6b20ed 100644 --- a/spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/metrics/SparkBigQueryReadSessionMetrics.java +++ b/spark-bigquery-connector-common/src/main/java/com/google/cloud/spark/bigquery/metrics/SparkBigQueryReadSessionMetrics.java @@ -41,7 +41,6 @@ public class SparkBigQueryReadSessionMetrics extends SparkListener private final LongAccumulator scanTimeAccumulator; private final LongAccumulator parseTimeAccumulator; private final String sessionId; - private final SparkSession sparkSession; private final long timestamp; private final DataFormat readDataFormat; private final DataOrigin dataOrigin; @@ -57,7 +56,6 @@ private SparkBigQueryReadSessionMetrics( DataFormat readDataFormat, DataOrigin dataOrigin, long numReadStreams) { - this.sparkSession = sparkSession; this.sessionId = sessionName; this.timestamp = timestamp; this.readDataFormat = readDataFormat; @@ -230,12 +228,12 @@ public void onJobEnd(SparkListenerJobEnd jobEnd) { Method buildMethod = eventBuilderClass.getDeclaredMethod("build"); - sparkSession + SparkSession.active() .sparkContext() .listenerBus() .post((SparkListenerEvent) buildMethod.invoke(builderInstance)); - sparkSession.sparkContext().removeSparkListener(this); + SparkSession.active().sparkContext().removeSparkListener(this); } catch (ReflectiveOperationException ex) { logger.debug("spark.events.SparkBigQueryConnectorReadSessionEvent library not in class path"); } diff --git a/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/metrics/SparkBigQueryReadSessionMetricsTest.java b/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/metrics/SparkBigQueryReadSessionMetricsTest.java index 62c766c3e8..b8629df0b2 100644 --- a/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/metrics/SparkBigQueryReadSessionMetricsTest.java +++ b/spark-bigquery-connector-common/src/test/java/com/google/cloud/spark/bigquery/metrics/SparkBigQueryReadSessionMetricsTest.java @@ -63,4 +63,30 @@ public void testReadSessionMetricsAccumulator() { metrics.incrementScanTimeAccumulator(5000); assertThat(metrics.getScanTime()).isEqualTo(6000); } + + @Test + public void testSerialization() throws Exception { + String sessionName = "projects/test-project/locations/us/sessions/testSession"; + SparkBigQueryReadSessionMetrics metrics = + SparkBigQueryReadSessionMetrics.from( + spark, + ReadSession.newBuilder().setName(sessionName).build(), + 10L, + DataFormat.ARROW, + DataOrigin.QUERY, + 10L); + + java.io.ByteArrayOutputStream bos = new java.io.ByteArrayOutputStream(); + java.io.ObjectOutputStream out = new java.io.ObjectOutputStream(bos); + out.writeObject(metrics); + out.close(); + + java.io.ByteArrayInputStream bis = new java.io.ByteArrayInputStream(bos.toByteArray()); + java.io.ObjectInputStream in = new java.io.ObjectInputStream(bis); + SparkBigQueryReadSessionMetrics deserializedMetrics = + (SparkBigQueryReadSessionMetrics) in.readObject(); + + assertThat(deserializedMetrics.getNumReadStreams()).isEqualTo(10L); + assertThat(deserializedMetrics.getBytesRead()).isEqualTo(0); + } }