
Commit 0be5f0e

Merge pull request #92 from pspoerri/ps/fix_pipeline
Fix Github CI build and some minor linting issues.
2 parents: b2ae548 + 3b2712a

File tree

6 files changed: +24 -81 lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 6 deletions
@@ -53,7 +53,7 @@ jobs:
           scala: 2.12.18
         - spark: 3.5.0
           scala: 2.13.8
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04 # Upgrading this version requires an additional sbt setup step.
     env:
       SPARK_VERSION: ${{ matrix.spark }}
       SCALA_VERSION: ${{ matrix.scala }}
@@ -68,18 +68,17 @@ jobs:
           cache: sbt
       - name: Check formatting
         shell: bash
-        run: |
+        run: |-
           echo "If either of these checks fail run: 'sbt scalafmtAll && sbt scalafmtSbt'"
           sbt scalafmtSbtCheck
           sbt scalafmtCheckAll
-      - name: Test Default Shuffle Fetch
+      - name: Run tests
         shell: bash
-        if: startsWith(matrix.scala, '2.12.')
         run: |
           sbt test
-      - name: Test Spark Shuffle Fetch
+      - name: Run tests with Spark Shuffle Fetch enabled
         shell: bash
-        if: startsWith(matrix.scala, '2.12.') && !startsWith(matrix.spark, '3.2.')
+        if: ${{ !startsWith(matrix.spark, '3.2.') }}
         env:
           USE_SPARK_SHUFFLE_FETCH: "true"
         run: |

README.md

Lines changed: 2 additions & 2 deletions
@@ -101,8 +101,8 @@ to Java > 11:
  --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED
  --add-opens=java.base/sun.nio.ch=ALL-UNNAMED
  --add-opens=java.base/sun.nio.cs=ALL-UNNAMED
- --add-opens=java.base/sun.security.action=ALL-UNNAMED -
- -add-opens=java.base/sun.util.calendar=ALL-UNNAMED
+ --add-opens=java.base/sun.security.action=ALL-UNNAMED
+ --add-opens=java.base/sun.util.calendar=ALL-UNNAMED
  --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED
  ```
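As a usage illustration (not part of this commit), the sketch below shows one way the corrected --add-opens flags can reach the executors from application code via spark.executor.extraJavaOptions; the equivalent driver-side flags normally have to be supplied when the driver JVM is launched, for example through spark-submit --conf spark.driver.extraJavaOptions=... . The object name and the flag subset shown here are illustrative, not taken from the repository.

import org.apache.spark.sql.SparkSession

// Hypothetical example object; only a subset of the README's flags is shown.
object AddOpensExample {
  private val addOpens = Seq(
    "--add-opens=java.base/sun.security.action=ALL-UNNAMED",
    "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED"
  ).mkString(" ")

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("add-opens-example")
      // Executor JVMs are forked after this point, so the setting takes effect for them;
      // driver flags must instead be passed on the spark-submit command line.
      .config("spark.executor.extraJavaOptions", addOpens)
      .getOrCreate()

    spark.range(10).count() // trivial action to confirm the session works
    spark.stop()
  }
}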

build.sbt

Lines changed: 4 additions & 15 deletions
@@ -3,10 +3,10 @@
 // SPDX-License-Identifier: Apache2.0
 //

-scalaVersion := sys.env.getOrElse("SCALA_VERSION", "2.12.15")
+scalaVersion := sys.env.getOrElse("SCALA_VERSION", "2.12.18")
 organization := "com.ibm"
 name := "spark-s3-shuffle"
-val sparkVersion = sys.env.getOrElse("SPARK_VERSION", "3.3.1")
+val sparkVersion = sys.env.getOrElse("SPARK_VERSION", "3.5.0")

 enablePlugins(GitVersioning, BuildInfoPlugin)

@@ -29,21 +29,10 @@ buildInfoKeys ++= Seq[BuildInfoKey](
 libraryDependencies ++= Seq(
   "org.apache.spark" %% "spark-core" % sparkVersion % "provided",
   "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
-  "org.apache.spark" %% "spark-hadoop-cloud" % sparkVersion % "compile"
+  "org.apache.spark" %% "spark-hadoop-cloud" % sparkVersion % "compile",
+  "org.scalatest" %% "scalatest" % "3.2.19" % Test
 )

-libraryDependencies ++= (if (scalaBinaryVersion.value == "2.12")
-                           Seq(
-                             "junit" % "junit" % "4.13.2" % Test,
-                             "org.scalatest" %% "scalatest" % "3.2.2" % Test,
-                             "ch.cern.sparkmeasure" %% "spark-measure" % "0.18" % Test,
-                             "org.scalacheck" %% "scalacheck" % "1.15.2" % Test,
-                             "org.mockito" % "mockito-core" % "3.4.6" % Test,
-                             "org.scalatestplus" %% "mockito-3-4" % "3.2.9.0" % Test,
-                             "com.github.sbt" % "junit-interface" % "0.13.3" % Test
-                           )
-                         else Seq())
-
 javacOptions ++= Seq("-source", "1.8", "-target", "1.8")
 javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:MaxPermSize=2048M", "-XX:+CMSClassUnloadingEnabled")
 scalacOptions ++= Seq("-deprecation", "-unchecked")

src/main/scala/org/apache/spark/shuffle/helper/S3ShuffleDispatcher.scala

Lines changed: 2 additions & 8 deletions
@@ -211,7 +211,6 @@ class S3ShuffleDispatcher extends Logging {
   def closeCachedBlocks(shuffleIndex: Int): Unit = {
     val filter = (blockId: BlockId) =>
       blockId match {
-        case RDDBlockId(_, _) => false
         case ShuffleBlockId(shuffleId, _, _) => shuffleId == shuffleIndex
         case ShuffleBlockBatchId(shuffleId, _, _, _) => shuffleId == shuffleIndex
         case ShuffleBlockChunkId(shuffleId, _, _, _) => shuffleId == shuffleIndex
@@ -223,14 +222,9 @@ class S3ShuffleDispatcher extends Logging {
         case ShuffleMergedDataBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex
         case ShuffleMergedIndexBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex
         case ShuffleMergedMetaBlockId(_, shuffleId, _, _) => shuffleId == shuffleIndex
-        case BroadcastBlockId(_, _) => false
-        case TaskResultBlockId(_) => false
-        case StreamBlockId(_, _) => false
-        case TempLocalBlockId(_) => false
-        case TempShuffleBlockId(_) => false
-        case TestBlockId(_) => false
+        case _ => false
       }
-    cachedFileStatus.remove(filter, _)
+    cachedFileStatus.remove(filter, None)
   }

   /** Open a block for writing.
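The change above collapses the explicitly enumerated non-shuffle block types into a single wildcard case, so any block that does not belong to the given shuffle simply falls through and is left alone. A minimal, self-contained sketch of that pattern-match style follows; the simplified block-ID case classes and the helper isForShuffle are hypothetical stand-ins, not Spark's real storage classes.

// Hypothetical, simplified block-ID hierarchy used only to illustrate the
// filter style used in closeCachedBlocks.
sealed trait BlockId
final case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId
final case class ShuffleBlockBatchId(shuffleId: Int, mapId: Long, startReduceId: Int, endReduceId: Int) extends BlockId
final case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId

object BlockFilterExample {
  // True only for shuffle blocks of the given shuffle; every other block type
  // hits the wildcard, mirroring the new `case _ => false`.
  def isForShuffle(shuffleIndex: Int)(blockId: BlockId): Boolean =
    blockId match {
      case ShuffleBlockId(shuffleId, _, _)         => shuffleId == shuffleIndex
      case ShuffleBlockBatchId(shuffleId, _, _, _) => shuffleId == shuffleIndex
      case _                                       => false
    }

  def main(args: Array[String]): Unit = {
    val blocks: Seq[BlockId] =
      Seq(ShuffleBlockId(1, 0L, 0), ShuffleBlockId(2, 0L, 0), RDDBlockId(7, 3))

    // Only the block that belongs to shuffle 1 passes the filter.
    println(blocks.filter(isForShuffle(1))) // List(ShuffleBlockId(1,0,0))
  }
}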

src/main/scala/org/apache/spark/shuffle/sort/S3ShuffleManager.scala

Lines changed: 1 addition & 6 deletions
@@ -23,20 +23,15 @@
 package org.apache.spark.shuffle.sort

 import com.ibm.SparkS3ShuffleBuild
-import org.apache.hadoop.fs.{Path, PathFilter}
 import org.apache.spark._
-import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.internal.Logging
 import org.apache.spark.shuffle._
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.helper.{S3ShuffleDispatcher, S3ShuffleHelper}
 import org.apache.spark.storage.S3ShuffleReader

-import java.io.IOException
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.duration.Duration
-import scala.concurrent.{Await, Future}

 /** This class was adapted from Apache Spark: SortShuffleManager.scala
  */

src/test/scala-2.12/org/apache/spark/shuffle/S3ShuffleManagerTest.scala renamed to src/test/scala/org/apache/spark/shuffle/S3ShuffleManagerTest.scala

Lines changed: 10 additions & 44 deletions
@@ -22,11 +22,11 @@

 package org.apache.spark.shuffle

-import ch.cern.sparkmeasure.StageMetrics
 import org.apache.spark._
 import org.apache.spark.sql.SparkSession
-import org.junit.Test
-import org.scalatest.Assertions._
+import org.scalatest._
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers

 import java.util.UUID

@@ -39,24 +39,21 @@ case class CombinerClass()
 /*
  * The test has been adapted from the following pull request https://github.com/apache/spark/pull/34864/files .
  */
-class S3ShuffleManagerTest {
+class S3ShuffleManagerTest extends AnyFunSuite {

-  @Test
-  def foldByKey(): Unit = {
+  test("foldByKey") {
     val conf = newSparkConf()
     runWithSparkConf(conf)
   }

-  @Test
-  def foldByKey_zeroBuffering(): Unit = {
+  test("foldByKey_zeroBuffering") {
     val conf = newSparkConf()
     conf.set("spark.reducer.maxSizeInFlight", "0")
     conf.set("spark.network.maxRemoteBlockSizeFetchToMem", "0")
     runWithSparkConf(conf)
   }

-  @Test
-  def runWithSparkConf_noMapSideCombine(): Unit = {
+  test("runWithSparkConf_noMapSideCombine") {
     val conf = newSparkConf()
     conf.set("spark.shuffle.sort.bypassMergeThreshold", "1000")
     val sc = new SparkContext(conf)
@@ -75,8 +72,7 @@ class S3ShuffleManagerTest {
     }
   }

-  @Test
-  def forceSortShuffle(): Unit = {
+  test("forceSortShuffle") {
     val conf = newSparkConf()
     conf.set("spark.shuffle.sort.bypassMergeThreshold", "1")
     val sc = new SparkContext(conf)
@@ -104,8 +100,7 @@ class S3ShuffleManagerTest {
     }
   }

-  @Test
-  def testCombineByKey(): Unit = {
+  test("testCombineByKey") {
     val conf = newSparkConf()
     val sc = new SparkContext(conf)
     try {
@@ -148,8 +143,7 @@ class S3ShuffleManagerTest {
     }
   }

-  @Test
-  def teraSortLike(): Unit = {
+  test("teraSortLike") {
     val conf = newSparkConf()
     conf.set("spark.shuffle.sort.bypassMergeThreshold", "1")
     val sc = new SparkContext(conf)
@@ -179,34 +173,6 @@ class S3ShuffleManagerTest {
     }
   }

-  @Test
-  def runWithSparkMeasure(): Unit = {
-    val conf = newSparkConf()
-    val sc = new SparkContext(conf)
-    val spark = SparkSession.builder().sparkContext(sc).getOrCreate()
-    val stageMetrics = StageMetrics(spark)
-    val result = stageMetrics.runAndMeasure {
-      spark.sql("select count(*) from range(1000) cross join range(1000) cross join range(1000)").take(1)
-    }
-    assert(result.map(r => r.getLong(0)).head === 1000000000)
-
-    val timestamp = System.currentTimeMillis()
-    stageMetrics.createStageMetricsDF(s"spark_measure_test_${timestamp}")
-    val metrics = stageMetrics.aggregateStageMetrics(s"spark_measure_test_${timestamp}")
-    // get all of the stats
-    val (runTime, bytesRead, recordsRead, bytesWritten, recordsWritten) =
-      metrics
-        .select("elapsedTime", "bytesRead", "recordsRead", "bytesWritten", "recordsWritten")
-        .take(1)
-        .map(r => (r.getLong(0), r.getLong(1), r.getLong(2), r.getLong(3), r.getLong(4)))
-        .head
-    println(
-      f"Elapsed: ${runTime}, bytesRead: ${bytesRead}, recordsRead: ${recordsRead}, bytesWritten ${bytesWritten}, recordsWritten: ${recordsWritten}"
-    )
-    spark.stop()
-    spark.close()
-  }
-
   private def runWithSparkConf(conf: SparkConf) = {
     val sc = new SparkContext(conf)