Skip to content

Commit 16e29d5

Browse files
committed
support external shuffle
Signed-off-by: zizhao <zizhao@nvidia.com>
1 parent ccf241b commit 16e29d5

34 files changed

+3440
-45
lines changed

pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ See file LICENSE for terms.
9898
<includes>
9999
<include>org/apache/spark/network/**</include>
100100
<include>org/apache/spark/shuffle/ucx/external/**</include>
101-
<include>org/apache/spark/shuffle/ucx/memory/**</include>
101+
<include>org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala</include>
102102
<include>org/apache/spark/shuffle/ucx/ShuffleTransport.scala</include>
103103
<include>org/apache/spark/shuffle/utils/**</include>
104104
</includes>
@@ -141,7 +141,7 @@ See file LICENSE for terms.
141141
<includes>
142142
<include>org/apache/spark/network/**</include>
143143
<include>org/apache/spark/shuffle/ucx/external/**</include>
144-
<include>org/apache/spark/shuffle/ucx/memory/**</include>
144+
<include>org/apache/spark/shuffle/ucx/memory/UcxLimitedMemPool.scala</include>
145145
<include>org/apache/spark/shuffle/ucx/ShuffleTransport.scala</include>
146146
<include>org/apache/spark/shuffle/utils/**</include>
147147
</includes>
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package org.apache.spark.network.shuffle
2+
3+
import java.io.File
4+
import java.nio.charset.StandardCharsets
5+
import java.lang.reflect.{Method, Field}
6+
7+
import scala.collection.mutable
8+
9+
import com.fasterxml.jackson.databind.ObjectMapper
10+
11+
import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem;
12+
import org.apache.hadoop.metrics2.impl.MetricsSystemImpl;
13+
14+
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
15+
import org.apache.spark.network.util.TransportConf
16+
import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId
17+
18+
import org.apache.spark.shuffle.utils.UcxLogging
19+
import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport
20+
21+
class ExternalUcxShuffleBlockResolver(conf: TransportConf, registeredExecutorFile: File)
22+
extends ExternalShuffleBlockResolver(conf, registeredExecutorFile) with UcxLogging {
23+
private[spark] final val APP_KEY_PREFIX = "AppExecShuffleInfo";
24+
private[spark] final val ucxMapper = new ObjectMapper
25+
private[spark] var dbAppExecKeyMethod: Method = _
26+
private[spark] val knownManagers = mutable.Set(
27+
"org.apache.spark.shuffle.sort.SortShuffleManager",
28+
"org.apache.spark.shuffle.unsafe.UnsafeShuffleManager",
29+
"org.apache.spark.shuffle.ExternalUcxShuffleManager")
30+
private[spark] var ucxTransport: ExternalUcxServerTransport = _
31+
32+
// init()
33+
34+
private[spark] def dbAppExecKey(appExecId: AppExecId): Array[Byte] = {
35+
// we stick a common prefix on all the keys so we can find them in the DB
36+
val appExecJson = ucxMapper.writeValueAsString(appExecId);
37+
val key = (APP_KEY_PREFIX + ";" + appExecJson);
38+
key.getBytes(StandardCharsets.UTF_8);
39+
}
40+
41+
// def init(): Unit = {
42+
// val clazz = Class.forName("org.apache.spark.network.shuffle.ExternalShuffleBlockResolver")
43+
// try {
44+
// dbAppExecKeyMethod = clazz.getDeclaredMethod("dbAppExecKey", classOf[AppExecId])
45+
// dbAppExecKeyMethod.setAccessible(true)
46+
// } catch {
47+
// case e: Exception => {
48+
// logError(s"Get dbAppExecKey from ExternalUcxShuffleBlockResolver failed: $e")
49+
// }
50+
// }
51+
// }
52+
53+
// def dbAppExecKey(fullId: AppExecId): Array[Byte] = {
54+
// dbAppExecKeyMethod.invoke(this, fullId).asInstanceOf[Array[Byte]]
55+
// }
56+
57+
def setTransport(transport: ExternalUcxServerTransport): Unit = {
58+
ucxTransport = transport
59+
}
60+
/** Registers a new Executor with all the configuration we need to find its shuffle files. */
61+
override def registerExecutor(
62+
appId: String,
63+
execId: String,
64+
executorInfo: ExecutorShuffleInfo): Unit = {
65+
val fullId = new AppExecId(appId, execId)
66+
logInfo(s"Registered executor ${fullId} with ${executorInfo}")
67+
if (!knownManagers.contains(executorInfo.shuffleManager)) {
68+
throw new UnsupportedOperationException(
69+
"Unsupported shuffle manager of executor: " + executorInfo)
70+
}
71+
try {
72+
if (db != null) {
73+
val key = dbAppExecKey(fullId)
74+
val value = ucxMapper.writeValueAsString(executorInfo).getBytes(StandardCharsets.UTF_8)
75+
db.put(key, value)
76+
}
77+
executors.put(fullId, executorInfo)
78+
} catch {
79+
case e: Exception => logError("Error saving registered executors", e)
80+
}
81+
}
82+
83+
override def applicationRemoved(appId: String, cleanupLocalDirs: Boolean): Unit = {
84+
super.applicationRemoved(appId, cleanupLocalDirs)
85+
ucxTransport.applicationRemoved(appId)
86+
}
87+
override def executorRemoved(executorId: String, appId: String): Unit = {
88+
super.executorRemoved(executorId, appId)
89+
ucxTransport.executorRemoved(executorId, appId)
90+
}
91+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package org.apache.spark.network.shuffle
2+
3+
import java.io.File
4+
5+
import org.apache.spark.network.server.OneForOneStreamManager
6+
import org.apache.spark.network.util.TransportConf
7+
8+
import org.apache.spark.shuffle.utils.UcxLogging
9+
import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport
10+
11+
class ExternalUcxShuffleBlockHandler(conf: TransportConf, registeredExecutorFile: File)
12+
extends ExternalShuffleBlockHandler(new OneForOneStreamManager(),
13+
new ExternalUcxShuffleBlockResolver(conf, registeredExecutorFile)) with UcxLogging {
14+
def ucxBlockManager(): ExternalUcxShuffleBlockResolver = {
15+
blockManager.asInstanceOf[ExternalUcxShuffleBlockResolver]
16+
}
17+
def setTransport(transport: ExternalUcxServerTransport): Unit = {
18+
ucxBlockManager.setTransport(transport)
19+
}
20+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package org.apache.spark.network.shuffle
2+
3+
import java.io.File
4+
5+
import org.apache.spark.network.server.OneForOneStreamManager
6+
import org.apache.spark.network.util.TransportConf
7+
8+
import org.apache.spark.shuffle.utils.UcxLogging
9+
import org.apache.spark.shuffle.ucx.ExternalUcxServerTransport
10+
11+
class ExternalUcxShuffleBlockHandler(conf: TransportConf, registeredExecutorFile: File)
12+
extends ExternalBlockHandler(new OneForOneStreamManager(),
13+
new ExternalUcxShuffleBlockResolver(conf, registeredExecutorFile)) with UcxLogging {
14+
def ucxBlockManager(): ExternalUcxShuffleBlockResolver = {
15+
blockManager.asInstanceOf[ExternalUcxShuffleBlockResolver]
16+
}
17+
def setTransport(transport: ExternalUcxServerTransport): Unit = {
18+
ucxBlockManager.setTransport(transport)
19+
}
20+
}

0 commit comments

Comments
 (0)